neural-structured-additive-learning / safareg

GNU General Public License v3.0

Random intercepts with `fac_processor` #3

Closed: vhmedina closed this issue 8 months ago

vhmedina commented 9 months ago

Hi!

When comparing the random-intercept estimates from mgcv and deepregression in a simple simulation, I obtain different results: the deepregression estimates tend to be shrunk more strongly toward zero than the mgcv ones.

Am I missing something? Please find a reproducible example below. Many thanks!

library(tidyverse)
library(mgcv)
library(deepregression)
library(safareg)

# N number of subjects
# Tobs number of measurements per subjects
simulate_gaussian_mixed <- function(N, Tobs, seed = 1) {
  set.seed(seed)
  df <- data.frame(id = rep(seq_len(N), 
                            each = Tobs),
                   time = gl(Tobs, 1, 
                             length = N * Tobs,
                             labels = paste0("T",1:Tobs)))
  # design matrix for the fixed effects
  X <- model.matrix(~ time, data = df)
  betas <- runif(Tobs,-1,1) # fixed effects coefficients
  sigma_b <- 0.5 # standard deviation of random intercepts
  sigma_y <- 0.1 # sd of the gaussian distr
  # simulate random effects
  b <- rnorm(N, sd = sigma_b)
  # linear predictor
  eta <- c(X %*% betas + b[df$id])
  # simulate response
  df$y <- rnorm(N * Tobs, mean = eta, sd = sigma_y)
  df$id <- factor(df$id)
  return(list(df=df, betas=betas, b=b))
}
N <- 500
Tobs <- 5
sim <- simulate_gaussian_mixed(N, Tobs, seed=123)
df <- sim$df
betas <- sim$betas
b <- sim$b
# estimation with mgcv
mod_mgcv <- bam(y ~ time + s(id, bs='re'), 
                data = df, 
                method = 'REML')
# estimation with deepregression
mod_dr <- deepregression(y=df$y,
               list_of_formulas = list(loc =~ time + main(id), scale=~ 1),
               additional_processors = list(main = fac_processor),
               data=df)
mod_dr %>% fit(epochs = 1000, early_stopping = TRUE)

# check Random Intercepts
cbind(unique(cbind(id=df$id,
             pred_re_mgcv=as.numeric(predict(mod_mgcv, type = "terms")[,2]))),
      pred_re_dr = as.numeric(get_partial_effect(mod_dr, 
                              names = names(coef(mod_dr))[2],
                              newdata = data.frame(id=factor(seq(N))))),
      true_re = b) %>% 
  data.frame() %>% 
  pivot_longer(-c(id,true_re)) %>% 
  ggplot(aes(x = true_re, y=value, col=name))+geom_point()+
  geom_abline(slope=1)

[Image: estimated vs. true random intercepts for mgcv and deepregression]

Output of sessionInfo():

R version 4.3.1 (2023-06-16)
Platform: x86_64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.7.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] safareg_0.1               deepregression_1.0.0      keras_2.13.0             
 [4] tfprobability_0.15.1.9000 tensorflow_2.14.0         mgcv_1.9-0               
 [7] nlme_3.1-163              lubridate_1.9.3           forcats_1.0.0            
[10] stringr_1.5.1             dplyr_1.1.4               purrr_1.0.2              
[13] readr_2.1.4               tidyr_1.3.0               tibble_3.2.1             
[16] ggplot2_3.4.4             tidyverse_2.0.0          

loaded via a namespace (and not attached):
 [1] utf8_1.2.4        generics_0.1.3    stringi_1.8.2     lattice_0.22-5    hms_1.1.3        
 [6] magrittr_2.0.3    grid_4.3.1        timechange_0.2.0  jsonlite_1.8.7    Matrix_1.6-3     
[11] whisker_0.4.1     tfruns_1.5.1      fansi_1.0.5       scales_1.2.1      cli_3.6.1        
[16] rlang_1.1.2       munsell_0.5.0     splines_4.3.1     base64enc_0.1-3   withr_2.5.2      
[21] tools_4.3.1       tzdb_0.4.0        colorspace_2.1-0  zeallot_0.1.0     reticulate_1.34.0
[26] vctrs_0.6.4       R6_2.5.1          png_0.1-8         lifecycle_1.0.4   pkgconfig_2.0.3  
[31] pillar_1.9.0      gtable_0.3.4      glue_1.6.2        Rcpp_1.0.11       tidyselect_1.2.0 
[36] rstudioapi_0.15.0 compiler_4.3.1 

davidruegamer commented 9 months ago

Hi @vhmedina ,

Thanks for your question / issue. There are two reasons why you see this additional shrinkage in the random effects when using deepregression:

  1. Optimization with stochastic gradient descent and its variants, such as Adam (the default in this case), implicitly biases the estimates, and so does early stopping. So even though the same objective is minimized as in mgcv, the optimization itself makes a difference and, in this case, induces additional shrinkage.
  2. The fac_processor defines a ridge-penalized linear (categorical) effect. While this induces the same "penalty" as assuming a normal distribution for the random effects, the amount of penalization is fixed: we treat it as a hyperparameter (la), and you could, e.g., grid-search over different values of this hyperparameter (in contrast, mgcv estimates the penalty parameter implicitly by estimating the random-effect variance).
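To make point 2 concrete, here is the standard correspondence between a ridge penalty and a Gaussian random intercept, written for the simulated model (the notation is assumed here and is not taken from either package; T corresponds to Tobs in the code). Up to a constant, the scaled negative log joint density of responses and random effects is a ridge-penalized least-squares criterion, and for given fixed effects the minimizer in a balanced design is a shrunken subject mean:

```latex
\begin{aligned}
y_{it} &= \mathbf{x}_{it}^\top\boldsymbol\beta + b_i + \varepsilon_{it},
  \qquad b_i \sim \mathcal N(0,\sigma_b^2),\quad \varepsilon_{it} \sim \mathcal N(0,\sigma_y^2),\\
-2\sigma_y^2 \log p(\mathbf y, \mathbf b \mid \boldsymbol\beta)
  &= \sum_{i=1}^{N}\sum_{t=1}^{T}\bigl(y_{it} - \mathbf{x}_{it}^\top\boldsymbol\beta - b_i\bigr)^2
   + \lambda \sum_{i=1}^{N} b_i^2 + \text{const},
  \qquad \lambda = \frac{\sigma_y^2}{\sigma_b^2},\\
\hat b_i &= \frac{T}{T+\lambda}\,\bigl(\bar y_{i\cdot} - \bar{\mathbf x}_{i\cdot}^\top\boldsymbol\beta\bigr)
  = \frac{T\sigma_b^2}{T\sigma_b^2 + \sigma_y^2}\,\bigl(\bar y_{i\cdot} - \bar{\mathbf x}_{i\cdot}^\top\boldsymbol\beta\bigr).
\end{aligned}
```

With the simulation's values (sigma_b = 0.5, sigma_y = 0.1, T = 5), lambda = 0.04 and the shrinkage factor is 5 / 5.04, roughly 0.992, so a fit that estimates the variances (as mgcv does via REML) should barely shrink. A fixed `la` plays the role of lambda here, although whether deepregression scales it the same way (e.g., relative to a summed or an averaged loss) is an assumption.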

Regarding point 2: We will hopefully soon have an update that allows smoothing parameters and random-effect variances to be learned automatically.

Let me know if you have any further questions.

HTH, David

vhmedina commented 8 months ago

Many thanks for the reply, @davidruegamer. This is really helpful. I played around a bit with a grid search over different values of lambda. The closest I got was with la=0 (no regularization). Results improved slightly when I switched the optimizer to SGD with a larger momentum, but I couldn't reach the mgcv results. Thanks again!
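The finding that la = 0 comes closest fits a simple closed-form check. The sketch below uses base R only, mirrors the simulation from the issue with a smaller N, and computes the exact ridge-penalized random intercepts for a grid of penalties. `ridge_re()` is a hypothetical helper, not part of deepregression or safareg, and `lambda` is on the sum-of-squares scale, which may differ from how deepregression scales `la`.

```r
# Sketch: exact ridge-penalized random intercepts in a balanced design.
# ridge_re() is a hypothetical helper (not a deepregression/safareg function).
# lambda penalizes the sum of squared errors; deepregression's `la` may be scaled differently.
set.seed(123)
N <- 50; Tobs <- 5
sigma_b <- 0.5; sigma_y <- 0.1
df <- data.frame(id = factor(rep(seq_len(N), each = Tobs)),
                 time = gl(Tobs, 1, length = N * Tobs))
X <- model.matrix(~ time, data = df)     # fixed effects (unpenalized)
Z <- model.matrix(~ id - 1, data = df)   # one dummy per subject (penalized)
b <- rnorm(N, sd = sigma_b)
df$y <- c(X %*% runif(Tobs, -1, 1)) + b[as.integer(df$id)] +
  rnorm(N * Tobs, sd = sigma_y)

# Minimize ||y - X beta - Z b||^2 + lambda * ||b||^2 and return only b_hat
ridge_re <- function(y, X, Z, lambda) {
  C <- cbind(X, Z)
  P <- diag(c(rep(0, ncol(X)), rep(lambda, ncol(Z))))
  coefs <- solve(crossprod(C) + P, crossprod(C, y))
  as.numeric(coefs[-seq_len(ncol(X))])
}

ybar_i <- as.numeric(tapply(df$y, df$id, mean))  # subject means
lambda_true <- sigma_y^2 / sigma_b^2             # 0.04, implied by the true variances
for (lambda in c(lambda_true, 1, 5, 25)) {
  b_hat <- ridge_re(df$y, X, Z, lambda)
  # here b_hat equals Tobs / (Tobs + lambda) * (ybar_i - grand mean) exactly
  cat(sprintf("lambda = %5.2f  factor = %.3f  slope of b_hat on b = %.3f\n",
              lambda, Tobs / (Tobs + lambda), coef(lm(b_hat ~ b))[[2]]))
}
```

At lambda = 0.04 the estimates shrink by less than 1%, while moderately larger penalties shrink them visibly; so, at an exact optimum, a penalty close to zero is what best matches an REML fit, and any remaining gap has to come from the optimizer (point 1).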

davidruegamer commented 8 months ago

Interesting. Thanks for sharing your results! I think it might not (always) be possible to reach the mgcv result due to the implicit regularization. Switching from Adam (i.e., parameter-specific learning rates) to SGD might in fact mitigate the problem as the implicit shrinkage is less specific, but chances are still small that one ends up with the same amount of penalization.
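The implicit shrinkage from the optimizer and early stopping (point 1 in the first reply) can be illustrated without deepregression. The sketch below assumes a simplified model y_it = b_i + eps_it (fixed effects already removed) and runs plain full-batch gradient descent, not Adam or mini-batch SGD, on the unpenalized squared error, starting from zero. `gd_random_intercepts()` and the step size `eta` are illustrative choices, not package functions or defaults.

```r
# Sketch: early-stopped full-batch gradient descent shrinks estimates toward zero.
# Simplified model y_it = b_i + eps_it; gd_random_intercepts() is hypothetical.
set.seed(1)
N <- 200; Tobs <- 5
b <- rnorm(N, sd = 0.5)
y <- matrix(b + rnorm(N * Tobs, sd = 0.1), nrow = N)  # row i holds subject i

gd_random_intercepts <- function(y, eta, n_steps) {
  b_hat <- rep(0, nrow(y))          # initialize all random intercepts at zero
  for (k in seq_len(n_steps)) {
    grad <- -rowSums(y - b_hat)     # gradient of 0.5 * sum((y_it - b_i)^2)
    b_hat <- b_hat - eta * grad     # plain gradient step, no penalty
  }
  b_hat
}

eta <- 0.02
ybar <- rowMeans(y)                 # unpenalized optimum: the subject means
for (n_steps in c(5, 10, 50, 200)) {
  b_hat <- gd_random_intercepts(y, eta, n_steps)
  # after k steps, b_hat = (1 - (1 - eta * Tobs)^k) * ybar
  cat(sprintf("steps = %3d  theoretical factor = %.3f  empirical = %.3f\n",
              n_steps, 1 - (1 - eta * Tobs)^n_steps,
              sum(b_hat * ybar) / sum(ybar^2)))
}
```

With plain gradient descent, stopping early acts like a uniform shrinkage factor on all subjects. Adam's per-parameter step sizes make the effective shrinkage differ across coefficients, which this simple sketch does not model.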