strengejacke / ggeffects

Estimated Marginal Means and Marginal Effects from Regression Models for ggplot2
https://strengejacke.github.io/ggeffects
Other
544 stars 35 forks source link

earth (MARS) support #520

Open Deleetdk opened 4 months ago

Deleetdk commented 4 months ago
library(earth)
#> Loading required package: Formula
#> Loading required package: plotmo
#> Loading required package: plotrix
library(ggeffects)
library(tidyverse)

fit <- earth(Sepal.Length ~ ., data = iris)

fit |> summary()
#> Call: earth(formula=Sepal.Length~., data=iris)
#> 
#>                     coefficients
#> (Intercept)            5.1316165
#> Speciesvirginica      -0.4535416
#> h(Sepal.Width-2.5)     0.5436514
#> h(3.5-Petal.Length)   -0.3104219
#> h(Petal.Length-3.5)    0.8249093
#> h(Petal.Width-2.3)    -2.4300752
#> 
#> Selected 6 of 10 terms, and 4 of 5 predictors
#> Termination condition: RSq changed by less than 0.001 at 10 terms
#> Importance: Petal.Length, Sepal.Width, Speciesvirginica, Petal.Width, ...
#> Number of terms at each degree of interaction: 1 5 (additive model)
#> GCV 0.09774692    RSS 12.59045    GRSq 0.8583984    RSq 0.8767675

# ggpredict() rejects earth models. try() reports the error but lets the rest
# of the script run when it is sourced non-interactively.
try(
  ggpredict(fit, terms = "Sepal.Width") |>
    plot()
)
#> Error : Models of class `earth` are not yet supported.

# Manual alternative: vary Sepal.Width over its observed range and hold the
# other predictors fixed (numeric ones at their mean, Species at "setosa").
newdata <- tibble(
  Sepal.Width = seq(
    min(iris$Sepal.Width),
    max(iris$Sepal.Width),
    length.out = 100
  ),
  Petal.Length = mean(iris$Petal.Length),
  Petal.Width = mean(iris$Petal.Width),
  # use the training factor levels so Species is encoded as during fitting
  Species = factor("setosa", levels = levels(iris$Species))
)

# add predictions; predict() on an earth model may return a one-column
# matrix, so flatten it to a plain numeric column
newdata$pred <- as.vector(predict(fit, newdata = newdata))

# plot the prediction curve over the observed data
newdata |>
  ggplot(aes(Sepal.Width, pred)) +
  geom_line() +
  geom_point(data = iris, aes(Sepal.Width, Sepal.Length)) +
  theme_minimal()

Created on 2024-05-08 with [reprex v2.1.0](https://reprex.tidyverse.org/)

[Image: earth model predictions of Sepal.Length across Sepal.Width, drawn over the iris data points]

In principle, supporting earth models should need only minor changes. One complication is that earth models cannot return standard errors or confidence intervals for their predictions, owing to inherent theoretical limitations. Even so, ggpredict should still be able to plot the model's point predictions, as the manual version above shows.

Deleetdk commented 4 months ago

I went ahead and implemented a poor man's version of ggpredict. Maybe the code will be useful to someone else:

#some dev code for testing that our poor man's ggpredict2 works
#use iris data, fit a linear model, predict the output, plot results
#also then fit the MARS and verify it works too

library(tidyverse)
library(earth)
#> Loading required package: Formula
#> Loading required package: plotmo
#> Loading required package: plotrix
library(ggeffects)

# make_focal_data_range(): values at which the first focal term is evaluated.
#
# x      - observed values of the focal variable.
# length - number of grid points to generate when `x` is numeric.
#
# Returns `length` evenly spaced values spanning the observed range of a
# numeric `x`, or the distinct observed values otherwise. Missing values are
# ignored, so a variable containing NAs still yields a usable grid.
make_focal_data_range <- function(x, length = 1000) {
  if (is.numeric(x)) {
    x_range <- range(x, na.rm = TRUE)
    seq(x_range[1], x_range[2], length.out = length)
  } else {
    unique(x[!is.na(x)])
  }
}

# make_focal_data_range_ordinal(): values for the second and later focal terms.
#
# x        - observed values of the focal variable.
# centiles - probabilities at which a numeric `x` is summarised; the default
#            pnorm(-2:2) gives roughly the 2.3%, 15.9%, 50%, 84.1% and 97.7%
#            centiles (mean and +-1/+-2 SD for normal data).
#
# Returns the (unnamed, de-duplicated) sample quantiles of a numeric `x`, or
# the distinct observed values otherwise. Missing values are ignored.
# Duplicated quantiles (common for discrete data) are collapsed so the
# prediction grid does not contain repeated rows.
make_focal_data_range_ordinal <- function(x, centiles = pnorm(seq(-2, 2))) {
  if (is.numeric(x)) {
    unique(quantile(x, probs = centiles, na.rm = TRUE, names = FALSE))
  } else {
    unique(x[!is.na(x)])
  }
}

# make_covar_data(): the constant value at which a covariate is held.
#
# x - observed values of the covariate.
#
# Returns the mean of a numeric `x` (ignoring NAs). Otherwise it returns the
# most frequent value (first one on ties), taken from `x` itself so that its
# type and factor levels are preserved.
make_covar_data <- function(x) {
  if (is.numeric(x)) {
    mean(x, na.rm = TRUE)
  } else {
    counts <- table(x)
    mode_label <- names(counts)[which.max(counts)]
    # `x == mode_label` is NA at missing positions; which() skips those, so
    # the first genuinely matching element is returned (never NA).
    x[which(x == mode_label)[1]]
  }
}

# parse_focal_term(): split a focal term such as "Petal.Length [1, 2.5]" into
# the bare variable name ("Petal.Length") and the explicitly requested values
# as a character vector (c("1", "2.5")); `values` is NULL without brackets.
parse_focal_term <- function(term) {
  term_name <- trimws(sub("\\[.*$", "", term))
  term_values <- NULL
  if (grepl("[", term, fixed = TRUE)) {
    # keep only the text between the first "[" and the following "]"
    inside <- sub("\\].*$", "", sub("^[^[]*\\[", "", term))
    term_values <- trimws(strsplit(inside, ",", fixed = TRUE)[[1]])
  }
  list(name = term_name, values = term_values)
}

# prep_newdata(): build the prediction grid (newdata) for ggpredict2().
#
# focal_terms - variables to vary. The first one spans its observed range
#   (numeric) or its distinct values; later ones use a few representative
#   centiles (numeric) or their distinct values. Any focal term may instead
#   list explicit values in brackets, e.g. "Petal.Length [1, 4, 6]".
# covar_terms - variables held constant (mean if numeric, mode otherwise).
# data        - data frame with the observed values of all terms.
#
# Returns every combination of the values (via expand_grid()), one column per
# variable. Errors if a focal term is also listed as a covariate or if a
# variable is not a column of `data`.
prep_newdata <- function(focal_terms, covar_terms, data) {
  parsed_terms <- lapply(focal_terms, parse_focal_term)
  focal_names <- vapply(parsed_terms, function(p) p$name, character(1))

  # compare bare names so "x [1, 2]" is caught as well as "x"
  if (any(focal_names %in% covar_terms)) {
    stop("focal terms cannot be in covar terms", call. = FALSE)
  }

  unknown_vars <- setdiff(c(focal_names, covar_terms), names(data))
  if (length(unknown_vars) > 0) {
    stop(
      "variables not found in data: ",
      paste(unknown_vars, collapse = ", "),
      call. = FALSE
    )
  }

  call_args <- list()

  for (i in seq_along(parsed_terms)) {
    term_name <- parsed_terms[[i]]$name
    term_values <- parsed_terms[[i]]$values
    observed <- data[[term_name]]

    call_args[[term_name]] <- if (!is.null(term_values)) {
      # explicit values: numeric variables get numbers, others keep labels
      if (is.numeric(observed)) as.numeric(term_values) else term_values
    } else if (i == 1L) {
      # first focal term: dense grid over its full range
      make_focal_data_range(observed)
    } else {
      # later focal terms: a handful of representative values
      make_focal_data_range_ordinal(observed)
    }
  }

  # covariates are held at a single constant value
  for (t in covar_terms) {
    call_args[[t]] <- make_covar_data(data[[t]])
  }

  rlang::exec(expand_grid, !!!call_args)
}

# get_model_preds(): predict `model` on `newdata` and return a tibble with
# columns pred, pred_lwr and pred_upr (one row per row of `newdata`).
#
# - earth models: point predictions only; pred_lwr/pred_upr are NA.
# - lm models: confidence interval from predict(interval = "confidence").
# - anything else (including glm, which inherits from "lm" but whose predict()
#   has no `interval` argument): a warning is raised and a confidence interval
#   is attempted. If predict() does not return a 3-column fit/lwr/upr matrix,
#   the point predictions are kept and the bounds are NA.
get_model_preds <- function(model, newdata) {
  model_classes <- class(model)
  newdata <- as.data.frame(newdata)

  if (inherits(model, "earth")) {
    return(tibble(
      pred = as.vector(predict(model, newdata = newdata)),
      pred_lwr = NA_real_,
      pred_upr = NA_real_
    ))
  }

  is_plain_lm <- inherits(model, "lm") && !inherits(model, "glm")
  if (!is_plain_lm) {
    warning(str_glue("model class `{str_c(model_classes, collapse = ', ')}` may not be supported"))
  }

  raw_preds <- predict(model, newdata = newdata, interval = "confidence")

  if (is.matrix(raw_preds) && ncol(raw_preds) == 3L) {
    raw_preds %>%
      as_tibble() %>%
      set_names(c("pred", "pred_lwr", "pred_upr"))
  } else {
    # no interval available: keep point predictions, mark bounds as missing
    tibble(
      pred = as.vector(raw_preds),
      pred_lwr = NA_real_,
      pred_upr = NA_real_
    )
  }
}

# ggpredict2(): a minimal stand-in for ggeffects::ggpredict().
#
# Builds a prediction grid from `data` with prep_newdata() (focal terms
# varied, covariates held constant) and predicts `model` on it with
# get_model_preds(). Returns the grid with the prediction columns appended
# column-wise.
ggpredict2 <- function(model, focal_terms, covar_terms, data) {
  grid <- prep_newdata(focal_terms, covar_terms, data)
  grid_preds <- get_model_preds(model, grid)
  bind_cols(grid, grid_preds)
}

#lm fit
iris_lm <- lm(
  Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species,
  data = iris
)

#mars fit
iris_mars <- earth(
  Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species,
  data = iris,
  degree = 2
)

iris_mars %>% summary()
#> Call: earth(formula=Sepal.Length~Sepal.Width+Petal.Length+Petal.Width+...),
#>             data=iris, degree=2)
#> 
#>                     coefficients
#> (Intercept)            5.0952768
#> Speciesvirginica      -0.4561422
#> h(2.5-Sepal.Width)     0.4683702
#> h(Sepal.Width-2.5)     0.5708274
#> h(3.5-Petal.Length)   -0.3059570
#> h(Petal.Length-3.5)    0.8366556
#> h(Petal.Width-2.3)    -2.4866291
#> 
#> Selected 7 of 20 terms, and 4 of 5 predictors
#> Termination condition: Reached nk 21
#> Importance: Petal.Length, Sepal.Width, Speciesvirginica, Petal.Width, ...
#> Number of terms at each degree of interaction: 1 6 (additive model)
#> GCV 0.1041697    RSS 12.46981    GRSq 0.8490941    RSq 0.8779484

#does regular ggpredict work on earth?
# try() reports the expected error but lets the script continue, so the file
# can be sourced as a whole instead of stopping here.
try(ggpredict(iris_mars, terms = "Petal.Length"))
#> Error : Models of class `earth` are not yet supported.
#no

#side by side
# The earth predictions have NA interval bounds, so only rows with finite
# bounds are handed to geom_ribbon(). Drawing an all-NA ribbon is what
# triggered the "no non-missing arguments to max" warnings.
bind_rows(
  ggpredict2(
    iris_lm,
    focal_terms = "Sepal.Width",
    covar_terms = c("Petal.Width", "Petal.Length", "Species"),
    data = iris
  ) %>% mutate(model = "lm"),
  ggpredict2(
    iris_mars,
    focal_terms = "Sepal.Width",
    covar_terms = c("Petal.Width", "Petal.Length", "Species"),
    data = iris
  ) %>% mutate(model = "mars")
) %>%
  ggplot(aes(x = Sepal.Width, y = pred, color = model)) +
  geom_line() +
  geom_ribbon(
    aes(ymin = pred_lwr, ymax = pred_upr),
    data = function(d) filter(d, !is.na(pred_lwr), !is.na(pred_upr)),
    alpha = 0.2,
    linewidth = 0
  )


#with a second focal term
# As above, the ribbon layer only receives rows with finite interval bounds
# (the earth rows have NA bounds), which avoids the all-NA ribbon warnings.
bind_rows(
  ggpredict2(
    iris_lm,
    focal_terms = c("Sepal.Width", "Petal.Length"),
    covar_terms = c("Petal.Width", "Species"),
    data = iris
  ) %>% mutate(model = "lm"),
  ggpredict2(
    iris_mars,
    focal_terms = c("Sepal.Width", "Petal.Length"),
    covar_terms = c("Petal.Width", "Species"),
    data = iris
  ) %>% mutate(model = "mars")
) %>%
  ggplot(aes(x = Sepal.Width, y = pred, color = factor(round(Petal.Length, 2)))) +
  geom_line() +
  geom_ribbon(
    aes(ymin = pred_lwr, ymax = pred_upr),
    data = function(d) filter(d, !is.na(pred_lwr), !is.na(pred_upr)),
    alpha = 0.2,
    linewidth = 0
  ) +
  facet_wrap("model")

Created on 2024-05-16 with reprex v2.1.0