vincentarelbundock closed this issue.
From Twitter:
That was my contribution to {ggeffects}
A version of that function can be found here (I've gone ahead and added support for marginaleffects::predictions()): https://mattansb.github.io/MSBMisc/reference/get_data_for_grid.html
Feel free to port this!
For now, this is out of scope. There's an infinite number of model diagnostics that could be implemented, and I want to keep the package simple and the number of functions limited by focusing on the 4 quantities of interest:
This would be rather trivial to implement, so I am still open to it, but I would need someone to convince me that it really belongs in this package.
FYI, this is quite easy to implement: we can use {ggeffects}'s residualize_over_grid.data.frame() to do all the heavy lifting, and we just need a little wrapper:
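residualize_over_grid.predictions <- function(grid, model, ...) {
  # The fitted model and the grid variables are read from the attributes
  # of the predictions object, so the `model` argument is overwritten here.
  model <- attr(grid, "model")
  cond_vars <- attr(grid, "newdata_variables_datagrid")
  # Hand a plain data frame (grid variables + predictions) to ggeffects
  ggeffects::residualize_over_grid(
    grid = as.data.frame(grid[c(cond_vars, "estimate")]),
    model = model,
    pred_name = "estimate"
  )
}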
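The idea behind residualizing over a grid can be illustrated in base R for the linear model used below: a partial residual is the model's prediction at a point where the non-focal predictors are held fixed, plus that observation's ordinary residual. This is a simplified sketch, not necessarily what ggeffects does internally; the reference values am = 0 and cyl = 8 are arbitrary choices for illustration.

```r
# Simplified by-hand partial residuals for `hp` in the thread's linear model.
# Base R only; the reference values am = 0 and cyl = 8 are illustrative choices.
mod <- lm(mpg ~ hp + am * factor(cyl), data = mtcars)

# Counterfactual copy of the data: every car keeps its own hp, while the
# non-focal predictors are fixed at the reference values.
nd <- transform(mtcars, am = 0, cyl = 8)

# Partial residual = prediction at the counterfactual row + the car's own residual.
partial <- predict(mod, newdata = nd) + residuals(mod)

# Conditional prediction line for hp at the same reference values.
hp_seq <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 50)
line_fit <- predict(mod, newdata = data.frame(hp = hp_seq, am = 0, cyl = 8))

# To draw it: plot(mtcars$hp, partial); lines(hp_seq, line_fit)
```

Cars that already have am = 0 and cyl = 8 keep their observed mpg as their partial residual; all other cars are shifted to where they would sit under the reference values, which is what makes the scatter comparable to a single prediction line.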
Now prediction grids can be used with a mix of predictions() + plot_predictions() + some ggplot code:
library(marginaleffects)
library(ggeffects)
library(ggplot2)
library(patchwork)
# Linear models -----------------------------------------------------------
mod <- lm(mpg ~ hp + am * factor(cyl),
          mtcars)
## Example 1 -----
grid <- predictions(mod, newdata = datagrid(hp = unique)) |>
  ggeffects::residualize_over_grid()
(p1 <- plot_predictions(mod, condition = "hp") +
    geom_point(aes(hp, estimate), data = grid))
# Compare to ggeffects:
p1_gge <- ggpredict(mod, c("hp"), condition = c(cyl = 8)) |>
  plot(residuals = TRUE)
patchwork::wrap_plots(p1, p1_gge)
## Example 2 -----
# discretize a numeric
grid <- predictions(mod, newdata = datagrid(cyl = unique, hp = c(100, 200))) |>
  ggeffects::residualize_over_grid()
(p2 <- plot_predictions(mod, condition = list("cyl", "hp" = c(100, 200))) +
    geom_point(aes(factor(cyl), estimate, color = factor(hp)), data = grid,
               position = position_dodge(0.4)))
# Compare to ggeffects:
p2_gge <- ggpredict(mod, c("cyl", "hp [100,200]"), condition = c(cyl = 8)) |>
  plot(residuals = TRUE)
patchwork::wrap_plots(p2, p2_gge)
# GLM ---------------------------------------------------------------------
mod <- glm(mpg ~ hp + am * factor(cyl),
           family = poisson(),
           mtcars)
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 22.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 18.700000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 18.100000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 14.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 24.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 22.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 19.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 17.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 16.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 17.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 10.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 10.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 14.700000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 32.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 30.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 33.900000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.500000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.500000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 13.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 19.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 27.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 30.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 19.700000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.400000
grid <- predictions(mod, newdata = datagrid(hp = unique)) |>
  residualize_over_grid()
(p3 <- plot_predictions(mod, condition = "hp") +
    geom_point(aes(hp, estimate), data = grid))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
p3_gge <- ggpredict(mod, c("hp"), condition = c(cyl = 8)) |>
  plot(residuals = TRUE)
patchwork::wrap_plots(p3, p3_gge) & coord_cartesian(ylim = c(5, 50))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
Created on 2023-02-10 with reprex v2.0.2
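For a GLM, one common way to build partial residuals (and the approach residualize_over_grid2(), posted further down this thread, takes) is to work on the link scale: add the working residuals to the link-scale prediction at the grid point, then apply the inverse link. A base-R sketch of that recipe, with two assumptions flagged: quasipoisson() is swapped in for poisson(), since mpg is not a count (it fits the same log-link mean model but never calls dpois(), so the non-integer warnings above do not appear), and am = 0, cyl = 8 are illustrative reference values.

```r
# Link-scale partial residuals for `hp` in a log-link GLM (base R only).
# quasipoisson() stands in for poisson(): same log-link mean model and
# estimates, but no dpois() call on the non-integer mpg values.
mod <- glm(mpg ~ hp + am * factor(cyl), family = quasipoisson(), data = mtcars)

# Illustrative reference values for the non-focal predictors (an assumption).
nd <- transform(mtcars, am = 0, cyl = 8)

eta_hat <- predict(mod, newdata = nd, type = "link")  # counterfactual prediction, link scale
e_work  <- residuals(mod, type = "working")           # working residuals, link scale
partial <- family(mod)$linkinv(eta_hat + e_work)      # back-transformed to the mpg scale
```

For the log link the working residual is (y - mu) / mu, so it is measured on the same scale as the linear predictor; with an identity link the recipe reduces to prediction plus ordinary residual, as in the linear-model case.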
Oh, this is cool!
Reopening to make sure I look at this in detail later.
Okay, so I rewrote some functions to make them faster and work better with marginaleffects. Hope this helps!
library(marginaleffects)
library(ggplot2)
.is_grid <- function(df) {
  unq <- lapply(df, unique)
  if (prod(sapply(unq, length)) != nrow(df)) {
    return(FALSE)
  }
  df2 <- do.call(expand.grid, args = unq)
  df2$..1 <- 1
  res <- merge(df, df2, by = colnames(df), all = TRUE)
  return(sum(res$..1) == sum(df2$..1))
}
match_grid <- function(grid, data) {
  # 1. Duplicate grid
  grid2 <- grid
  # 2. Remove fixed columns and columns not in data from the duplicate
  is_fixed <- sapply(grid2, insight::has_single_value)
  grid2 <- grid2[!is_fixed]
  grid2 <- grid2[intersect(colnames(grid2), colnames(data))]
  # 3. Test if grid
  stopifnot(.is_grid(grid2))
  # 4. For each col in grid, get the unique values
  unqs <- lapply(grid2, unique)
  to_char <- sapply(unqs, function(x) is.character(x) || is.factor(x) || is.logical(x))
  unqs[to_char] <- lapply(unqs[to_char], as.character)
  # 5. For each row in data:
  for (i in seq_len(nrow(data))) {
    for (j in names(unqs)) {
      # For each column; missing values stay NA
      if (is.na(data[[j]][i])) next
      if (is.factor(unqs[[j]]) || is.logical(unqs[[j]]) || is.character(unqs[[j]])) {
        # If (logical, factor, char), match exactly. If not matched, return NA.
        jidx <- which(as.character(data[[j]][i]) == as.character(unqs[[j]]))
        if (length(jidx) == 0L) {
          data[[j]][i] <- NA
          next
        }
      } else {
        # If numeric, find closest value in grid values
        jidx <- which.min(abs(data[[j]][i] - unqs[[j]]))
      }
      data[[j]][i] <- unqs[[j]][jidx]
    }
  }
  # 6. Add back columns from (2).
  fixed_cols <- names(is_fixed)[is_fixed]
  data[, fixed_cols] <- grid[1, fixed_cols]
  return(data)
}
residualize_over_grid2 <- function(grid, model, ...) {
  # Working residuals are on the link scale
  e <- residuals(model, type = "working")
  inv <- insight::link_inverse(model)
  mi <- insight::model_info(model)
  # Predict on the link scale unless the link is the identity
  type <- ifelse(mi$link_function == "identity", "response", "link")
  data <- insight::get_data(model)
  # Move every observation onto its matching grid point
  grid <- match_grid(grid, data)
  # Use the `model` argument (not a global `mod`)
  grid2 <- marginaleffects::predictions(model, newdata = grid, type = type, ...)
  grid2[c("type", "std.error", "statistic", "p.value", "conf.low", "conf.high")] <- NULL
  # Grid prediction + residual, back-transformed to the response scale
  grid2[["estimate"]] <- inv(grid2[["estimate"]] + e)
  grid2
}
# Linear models -----------------------------------------------------------
mod <- lm(mpg ~ hp + am * factor(cyl),
          mtcars)
## Example 1 -----
grid <- datagrid(model = mod, hp = unique) |>
  residualize_over_grid2(mod)
(p1 <- plot_predictions(mod, condition = "hp") +
    geom_point(aes(hp, estimate), data = grid))
# Compare to ggeffects:
p1_gge <- ggeffects::ggpredict(mod, c("hp"), condition = c(cyl = 8)) |>
  plot(residuals = TRUE)
patchwork::wrap_plots(p1, p1_gge)
## Example 2 -----
# discretize a numeric
grid <- datagrid(model = mod, cyl = unique, hp = c(100, 200)) |>
  residualize_over_grid2(mod)
(p2 <- plot_predictions(mod, condition = list("cyl", "hp" = c(100, 200))) +
    geom_point(aes(factor(cyl), estimate, color = factor(hp)), data = grid,
               position = position_dodge(0.4)))
# Compare to ggeffects:
p2_gge <- ggeffects::ggpredict(mod, c("cyl", "hp [100,200]"), condition = c(cyl = 8)) |>
  plot(residuals = TRUE)
patchwork::wrap_plots(p2, p2_gge)
# GLM ---------------------------------------------------------------------
mod <- glm(mpg ~ hp + am * factor(cyl),
           family = poisson(),
           mtcars)
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 22.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 18.700000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 18.100000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 14.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 24.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 22.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 19.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 17.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 16.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 17.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 10.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 10.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 14.700000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 32.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 30.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 33.900000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.500000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.500000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 13.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 19.200000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 27.300000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 30.400000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 15.800000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 19.700000
#> Warning in dpois(y, mu, log = TRUE): non-integer x = 21.400000
grid <- datagrid(model = mod, hp = unique) |>
  residualize_over_grid2(mod)
(p3 <- plot_predictions(mod, condition = "hp") +
    geom_point(aes(hp, estimate), data = grid))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
# Compare to ggeffects:
p3_gge <- ggeffects::ggpredict(mod, c("hp"), condition = c(cyl = 8)) |>
  plot(residuals = TRUE)
patchwork::wrap_plots(p3, p3_gge) & coord_cartesian(ylim = c(5, 50))
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
# LMM ---------------------------------------------------------------------
set.seed(3)  # make the random subset of participants reproducible
stroop <- afex::stroop |>
  subset(study == 1 & acc == 1)
# keep 10 randomly chosen participants (sample ids, not rows)
stroop <- subset(stroop, pno %in% sample(unique(pno), size = 10))
mod <- lme4::lmer(rt ~ condition + (condition | pno),
                  data = stroop)
grid2 <- datagrid(model = mod, condition = unique, pno = unique) |>
  residualize_over_grid2(mod)
pj <- position_jitter(0.2, 0, seed = 3)
(p4 <- plot_predictions(mod, condition = "condition") +
    geom_point(aes(condition, estimate), data = grid2, position = pj))
# Compare to effects
eff <- effects::predictorEffect("condition", mod, partial.residuals = TRUE)
plot(eff)
Created on 2023-02-12 with reprex v2.0.2
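Two of the steps in match_grid() above can also be written without explicit loops: checking that the non-fixed grid columns form a complete grid (step 3), and snapping each numeric observation to its nearest grid value (the numeric branch of step 5). A sketch with two hypothetical helpers, is_full_grid() and snap_to_grid(); these names are illustrative and not part of marginaleffects, insight, or ggeffects.

```r
# Hypothetical helpers (not from any package) mirroring two steps of match_grid().

# Step 3: `df` is a complete grid when it has no duplicated rows and its row
# count equals the product of the number of unique values per column.
is_full_grid <- function(df) {
  n_unique <- vapply(df, function(x) length(unique(x)), integer(1))
  anyDuplicated(df) == 0 && nrow(df) == prod(n_unique)
}

# Step 5 (numeric columns): replace each value of `x` with the closest entry
# of `values`; missing values map to NA.
snap_to_grid <- function(x, values) {
  idx <- vapply(x, function(v) {
    if (is.na(v)) NA_integer_ else which.min(abs(v - values))
  }, integer(1))
  values[idx]
}

is_full_grid(expand.grid(hp = c(100, 200), cyl = c(4, 6, 8)))              # TRUE
is_full_grid(data.frame(hp = c(100, 100, 200, 200), cyl = c(4, 4, 6, 6)))  # FALSE: duplicated rows
snap_to_grid(c(90, 160, 240, NA), values = c(100, 200))                    # 100 200 200 NA
```

The grid check relies on a counting argument: distinct rows drawn from the Cartesian product of the column values cover the whole product exactly when their number equals the product of the per-column counts, so no merge() is needed. snap_to_grid() only covers numeric columns; factor, character, and logical columns would still need the exact matching used in match_grid().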
@mattansb I played with this code a bit and it is super cool. Thanks for taking the time.
I thought about it for a while, and at this point I've half-concluded that this is still out of scope.
My overarching design goal is to keep the number of arguments and functions limited, but to make sure that they work for all supported models. This is important because I want to avoid scope creep, make sure I have enough free time to polish and maintain the full feature set, and because I'd like to have as consistent an interface as possible across models.
Partial residual plots are obviously super useful and important for lm and glm models, but it's not clear to me how one would generalize the functions above to support all 70+ model types.
Also, I've intentionally kept the number of dependencies quite small to allow people to write wrappers with enhanced functionality. Daniel has already started to do that by integrating some of the delta method hypothesis testing in ggeffects.
Anyway, just wanted to write this to make sure you knew I thought about this seriously and that I really appreciate the time you took to write it all up.
(Maybe this could be in a vignette somewhere?)
@vincentarelbundock Thanks for the update and the detailed response.
(Just BTW - the function I wrote should already work for any model supported by {marginaleffects} and {insight} that provides a residuals() method.)
I understand where you are coming from - I guess I can't expect {marginaleffects} to literally do everything (:
I think you leave me no choice but to write a "partial residual plots with {marginaleffects}" blog post!
Yesssss!
https://strengejacke.github.io/ggeffects/articles/introduction_partial_residuals.html