gavinsimpson / gratia

ggplot-based graphics and useful functions for GAMs fitted using the mgcv package
https://gavinsimpson.github.io/gratia/

Extracting the coefficients of a GAM model #318

Closed onetimestats2024 closed 4 weeks ago

onetimestats2024 commented 1 month ago

Hello!

I simulated some data in R and fitted a GAM to it. The model ran successfully; here is the code:

library(mgcv)
library(ggplot2)
library(gridExtra)

set.seed(123)

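n <- 200
x1 <- seq(0, 10, length.out = n)  # evenly spaced grid on [0, 10]
x2 <- seq(0, 5, length.out = n)   # evenly spaced grid on [0, 5]; note this is exactly x1 / 2
x3 <- runif(n, 0, 15)             # uniform random draws on [0, 15]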

y <- sin(x1) + log(x2 + 1) + 0.1 * x3 + rnorm(n, sd = 0.2)

data <- data.frame(x1 = x1, x2 = x2, x3 = x3, y = y)

gam_model <- gam(y ~ s(x1) + s(x2) + s(x3), data = data)

pred_x1 <- predict(gam_model, newdata = data.frame(x1 = x1, x2 = mean(x2), x3 = mean(x3)), se.fit = TRUE)
pred_x2 <- predict(gam_model, newdata = data.frame(x1 = mean(x1), x2 = x2, x3 = mean(x3)), se.fit = TRUE)
pred_x3 <- predict(gam_model, newdata = data.frame(x1 = mean(x1), x2 = mean(x2), x3 = x3), se.fit = TRUE)

plot_data_x1 <- data.frame(x = x1, y_pred = pred_x1$fit, se = pred_x1$se.fit)
plot_data_x2 <- data.frame(x = x2, y_pred = pred_x2$fit, se = pred_x2$se.fit)
plot_data_x3 <- data.frame(x = x3, y_pred = pred_x3$fit, se = pred_x3$se.fit)

# One panel per covariate: predicted y with an approximate +/- 2 SE band.
# `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0).
p1 <- ggplot(plot_data_x1, aes(x = x, y = y_pred)) +
  geom_line(color = "red", linewidth = 1) +
  geom_ribbon(aes(ymin = y_pred - 2 * se, ymax = y_pred + 2 * se), alpha = 0.2) +
  labs(title = "Smooth curve for x1", y = "Predicted y") +
  theme_minimal()

p2 <- ggplot(plot_data_x2, aes(x = x, y = y_pred)) +
  geom_line(color = "blue", linewidth = 1) +
  geom_ribbon(aes(ymin = y_pred - 2 * se, ymax = y_pred + 2 * se), alpha = 0.2) +
  labs(title = "Smooth curve for x2", y = "Predicted y") +
  theme_minimal()

p3 <- ggplot(plot_data_x3, aes(x = x, y = y_pred)) +
  geom_line(color = "green", linewidth = 1) +
  geom_ribbon(aes(ymin = y_pred - 2 * se, ymax = y_pred + 2 * se), alpha = 0.2) +
  labs(title = "Smooth curve for x3", y = "Predicted y") +
  theme_minimal()

grid.arrange(p1, p2, p3, ncol = 3)

The visuals look like this:

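[Image: three side-by-side panels titled "Smooth curve for x1", "Smooth curve for x2", and "Smooth curve for x3", each showing predicted y as a colored line with a shaded band of plus or minus 2 standard errors]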

Is it possible to extract the coefficients of this model in R? For example, I would like access to the (approximate) mathematical function used to draw each of these colored curves.

Can this be done in R? Ideally, I would like to be able to extract all the coefficients from the output of this GAM model after running the R code and recreate these plots myself (i.e. the colored lines).
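
One way to approach this, sketched under some assumptions: mgcv represents each smooth as a linear combination of basis functions, so `coef()` returns the basis weights and `predict(..., type = "lpmatrix")` returns the basis functions evaluated at the supplied data. Multiplying the columns of that matrix that belong to one smooth by the matching coefficients should reproduce that smooth's (centred) contribution. The example refits a similar model on freshly simulated data; the names `dat`, `m`, `Xp`, `idx`, and `f_x1` are illustrative, and x2 is drawn at random (unlike the code above, where x2 is exactly x1 / 2) so that the smooths are separately identifiable.

```r
library(mgcv)

set.seed(123)
n <- 200
# Assumption: x2 is random here so that s(x1) and s(x2) are not confounded
dat <- data.frame(
  x1 = seq(0, 10, length.out = n),
  x2 = runif(n, 0, 5),
  x3 = runif(n, 0, 15)
)
dat$y <- sin(dat$x1) + log(dat$x2 + 1) + 0.1 * dat$x3 + rnorm(n, sd = 0.2)

m <- gam(y ~ s(x1) + s(x2) + s(x3), data = dat, method = "REML")

# All model coefficients: the intercept plus one weight per basis function
beta <- coef(m)

# Basis functions evaluated at the data (one column per coefficient)
Xp <- predict(m, newdata = dat, type = "lpmatrix")

# Columns belonging to s(x1), e.g. "s(x1).1", "s(x1).2", ...
idx <- grep("s(x1)", colnames(Xp), fixed = TRUE)

# Rebuild the estimated smooth of x1 by hand: basis columns times coefficients
f_x1 <- as.numeric(Xp[, idx] %*% beta[idx])

# Cross-check against mgcv's own per-term predictions
f_x1_check <- as.numeric(predict(m, newdata = dat, type = "terms")[, "s(x1)"])
all.equal(f_x1, f_x1_check)
```

Adding the intercept and the other smooths' contributions at fixed values of x2 and x3 shifts this curve vertically to the kind of prediction plotted above. There is no simple closed-form formula to read off, since the coefficients only have meaning together with the basis functions.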

(The reason I am asking is that I might later need to take numerical derivatives of these functions.)
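
For the derivative, one common finite-difference approach (a sketch, not necessarily what any package does internally) is to evaluate the lpmatrix slightly to the left and right of each x1 value, difference the two matrices, and multiply by the coefficients. The same differenced matrix gives an approximate standard error via `vcov()`. The step size `eps`, the evaluation grid, and all object names are assumptions chosen for illustration, and the data are re-simulated with a random x2 so that the example is self-contained. gratia also provides a `derivatives()` function for this purpose.

```r
library(mgcv)

set.seed(123)
n <- 200
# Assumption: illustrative data with a random x2 (not the exact data above)
dat <- data.frame(
  x1 = seq(0, 10, length.out = n),
  x2 = runif(n, 0, 5),
  x3 = runif(n, 0, 15)
)
dat$y <- sin(dat$x1) + log(dat$x2 + 1) + 0.1 * dat$x3 + rnorm(n, sd = 0.2)
m <- gam(y ~ s(x1) + s(x2) + s(x3), data = dat, method = "REML")

eps <- 1e-4  # finite-difference step (assumed; tune to the scale of x1)

# Grid over x1, holding x2 and x3 fixed at their means
grid <- data.frame(
  x1 = seq(0.5, 9.5, length.out = 100),
  x2 = mean(dat$x2),
  x3 = mean(dat$x3)
)
grid_lo <- transform(grid, x1 = x1 - eps)
grid_hi <- transform(grid, x1 = x1 + eps)

# Central difference of the basis functions with respect to x1
X_lo <- predict(m, newdata = grid_lo, type = "lpmatrix")
X_hi <- predict(m, newdata = grid_hi, type = "lpmatrix")
Xd <- (X_hi - X_lo) / (2 * eps)

# Estimated derivative of the fitted function with respect to x1
d_x1 <- as.numeric(Xd %*% coef(m))

# Approximate standard error: sqrt(diag(Xd %*% V %*% t(Xd)))
se_d <- sqrt(rowSums((Xd %*% vcov(m)) * Xd))

head(data.frame(x1 = grid$x1, derivative = d_x1, se = se_d))
```

Because x2 and x3 are held fixed, their columns cancel in the difference, so `d_x1` is the derivative of the x1 smooth alone. In this simulation the true derivative is cos(x1), which gives a simple visual check.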