gavinsimpson / gratia

ggplot-based graphics and useful functions for GAMs fitted using the mgcv package
https://gavinsimpson.github.io/gratia/

Function to obtain the overall penalty matrix from an mgcv object #238

Open fhui28 opened 1 year ago

fhui28 commented 1 year ago

Just thought this might be interesting to implement, given that a `penalty()` function has (mostly) been implemented in gratia and that, to my knowledge, mgcv has no direct way to compute this.

By the overall penalty (smoothing) matrix, I mean $\sum_{j=1}^p \lambda_j S_j$, where the $\lambda_j$ are the smoothing parameters estimated from a `gam()` fit and the $S_j$ are the corresponding penalty matrices, each embedded in the full coefficient space. This also covers tensor-product and factor-smooth interactions, where a single smooth can have several $S_j$ and $\lambda_j$.
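For a single term, the sum can be formed from pieces mgcv already stores on the fitted object: `object$smooth[[i]]$S` holds the term's penalty matrices and `object$sp` the estimated smoothing parameters. The sketch below does this for a `te()` smooth, which has one penalty per margin. It uses mgcv's `gamSim()` test data rather than gratia's `data_sim()`, and the seed and sample size are arbitrary. The result lives in the term's own coefficient block, so it still has to be placed at the right rows and columns of the full model.

```r
library(mgcv)

set.seed(2)
dat <- gamSim(1, n = 200, verbose = FALSE)

# A tensor-product smooth carries one penalty (and one smoothing parameter) per margin
m <- gam(y ~ te(x0, x1), data = dat, method = "REML")
sm <- m$smooth[[1]]
length(sm$S)  # number of penalty matrices for this term
m$sp          # the estimated lambdas, one per penalty matrix

# lambda_1 * S_1 + lambda_2 * S_2, expressed in this term's own coefficient block
S_lambda <- Reduce(`+`, Map(function(lambda, S) lambda * S, m$sp, sm$S))
dim(S_lambda)
```

Because the model has only one smooth, `m$sp` lines up directly with `sm$S`. With several smooths, the smoothing parameters have to be split up term by term.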

Hopefully it's correct =P

```r
# Get the full penalty matrix S_lambda = sum_j lambda_j S_j for a fitted gam,
# embedded in the full coefficient space. Each smooth's block is placed using its
# first.para/last.para indices, so parametric columns get zero rows/columns and
# models without parametric columns (e.g. `- 1`) are handled too.
get_bigS <- function(object) {
     num_X <- length(coef(object))
     bigS <- matrix(0, num_X, num_X)

     # Number of penalty matrices (and smoothing parameters) per smooth: te, ti, and
     # fs smooths can have several, unpenalized smooths (fx = TRUE) have none
     num_S <- vapply(object$smooth, function(x) length(x$S), integer(1))
     # Assumes no linked (id) or fixed smoothing parameters and no paraPen terms,
     # so object$sp lines up one-to-one with the penalties taken smooth by smooth
     stopifnot(sum(num_S) == length(object$sp))
     last_sp <- cumsum(num_S)

     for (j in seq_along(object$smooth)) {
          if (num_S[j] == 0)
               next
          smooth_j <- object$smooth[[j]]
          idx <- smooth_j$first.para:smooth_j$last.para
          sp_j <- object$sp[(last_sp[j] - num_S[j] + 1):last_sp[j]]
          # Sum lambda_l * S_l over this smooth's penalties, in its own block
          for (l in seq_len(num_S[j]))
               bigS[idx, idx] <- bigS[idx, idx] + sp_j[l] * smooth_j$S[[l]]
     }

     return(Matrix::Matrix(bigS, sparse = TRUE))
}
```
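Dropping the leading parametric columns with a negative index such as `-(1:num_parametric_cols)` is fragile when a model has no parametric columns at all (for example a formula ending in `- 1`). In R, `1:0` is `c(1, 0)`, so the first column is dropped anyway, and the obvious alternative `-seq_len(0)` selects nothing. The short base-R sketch below shows both behaviours on a plain vector; the same rules apply to matrix rows and columns. A logical mask behaves as intended, and indexing each smooth directly by `first.para:last.para` sidesteps the issue altogether.

```r
# Base-R pitfall when dropping the first `n_par` elements with negative indices
x <- 1:5

n_par <- 0                 # e.g. a model with no parametric columns
x[-(1:n_par)]              # 1:0 is c(1, 0), so element 1 is dropped anyway: 2 3 4 5
x[-seq_len(n_par)]         # -integer(0) selects nothing at all: integer(0)
x[seq_along(x) > n_par]    # logical mask keeps all 5 elements, as intended

n_par <- 2
x[seq_along(x) > n_par]    # keeps elements 3, 4, 5
```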

```r
library(gratia)
load_mgcv()
dat <- data_sim("eg4", n = 400, seed = 42)

# A single smooth with one penalty
m <- gam(y ~ s(x0, bs = "cr"), data = dat, method = "REML")
penalty(m)
get_bigS(m)
get_bigS(m)/m$sp # Dividing by the single smoothing parameter recovers S, matching penalty()

# Several smooths, including a factor-smooth interaction
m <- gam(y ~ s(x0, bs = "cr") + s(x1, bs = "cr") + s(x2, by = fac, bs = "fs"),
         data = dat, method = "REML")
get_bigS(m)

# A tensor-product smooth with one penalty per margin
m <- gam(y ~ te(x0, x1) + s(x2, by = fac, bs = "fs"), data = dat, method = "REML")
get_bigS(m)
```
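A rough way to check an assembled matrix is to use the penalized normal equations. This assumes a Gaussian identity-link fit with no prior weights and no linked or fixed smoothing parameters. Under those conditions, the fitted coefficients should solve $(X^\top X + S_\lambda)\hat\beta = X^\top y$ at the estimated $\lambda_j$. The sketch below builds $S_\lambda$ inline, so it runs without `get_bigS()`. It uses mgcv's `gamSim()` data (the seed and model are arbitrary choices) and compares the solution with `coef(m)`.

```r
library(mgcv)

set.seed(3)
dat <- gamSim(1, n = 300, verbose = FALSE)
m <- gam(y ~ s(x0, bs = "cr") + te(x1, x2), data = dat, method = "REML")

X <- predict(m, type = "lpmatrix")  # model matrix for the fitting data
p <- ncol(X)

# Assemble S_lambda = sum_j lambda_j S_j in the full p x p coefficient space.
# Assumes m$sp lines up one-to-one with the penalties taken smooth by smooth
# (no linked or fixed smoothing parameters, no paraPen terms).
S_lambda <- matrix(0, p, p)
k <- 0
for (sm in m$smooth) {
  idx <- sm$first.para:sm$last.para
  for (S in sm$S) {
    k <- k + 1
    S_lambda[idx, idx] <- S_lambda[idx, idx] + m$sp[k] * S
  }
}

# Gaussian identity link, no prior weights: the coefficients should solve
# (X'X + S_lambda) beta = X'y at the estimated smoothing parameters.
beta <- solve(crossprod(X) + S_lambda, crossprod(X, dat$y))
all.equal(unname(drop(beta)), unname(coef(m)), tolerance = 1e-6)
```

For non-Gaussian families the same idea applies with the working weights from the final PIRLS iteration in place of the identity weighting. This sketch does not attempt that.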