yrosseel / lavaan

an R package for structural equation modeling and more
http://lavaan.org

Regularized out-of-sample SEM #346

Closed: mikedmolina closed this pull request 4 months ago.

mikedmolina commented 4 months ago

Updated the existing lavPredictY function to perform regularized out-of-sample prediction, and added the lavPredictY_cv function, which selects the value of the lambda penalty by cross-validation. Updated the documentation to reflect this change. The DOI for the reference is still pending.
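
As a rough illustration of what regularized out-of-sample prediction can look like, here is a base-R sketch of a ridge-type conditional-mean predictor. It is an assumption, not the lavaan code: it assumes the penalty works by adding lambda to the diagonal of the covariance matrix of the predictors, and it uses sample moments where lavPredictY() would use the model-implied ones. The helper predict_y_ridge() is hypothetical.

```r
# Hypothetical helper, NOT the lavaan implementation: ridge-type regularized
# conditional-mean prediction of the y variables from the x variables.
# S and mu stand in for the model-implied covariance matrix and mean vector.
predict_y_ridge <- function(S, mu, xnames, ynames, newx, lambda = 0) {
  Sxx <- S[xnames, xnames, drop = FALSE]
  Syx <- S[ynames, xnames, drop = FALSE]
  # Assumed penalty: add lambda to the diagonal of the predictor covariance
  Sxx.reg <- Sxx + lambda * diag(length(xnames))
  B <- Syx %*% solve(Sxx.reg)                                          # q x p weights
  Xc <- sweep(as.matrix(newx[, xnames, drop = FALSE]), 2, mu[xnames])  # centre x
  Yhat <- sweep(Xc %*% t(B), 2, mu[ynames], "+")                       # add y means
  colnames(Yhat) <- ynames
  Yhat
}

# Usage on simulated data with two nearly collinear predictors (x1 and x3)
set.seed(1)
n <- 50
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
dat$x3 <- dat$x1 + rnorm(n, sd = 0.1)
dat$y1 <- dat$x1 + 0.5 * dat$x2 + rnorm(n)
S  <- cov(dat)
mu <- colMeans(dat)
xn <- c("x1", "x2", "x3")
yhat.ols   <- predict_y_ridge(S, mu, xn, "y1", dat, lambda = 0)
yhat.ridge <- predict_y_ridge(S, mu, xn, "y1", dat, lambda = 1)
```

With lambda = 0 the sketch reproduces the OLS fitted values of lm(y1 ~ x1 + x2 + x3); a larger lambda shrinks the regression weights and thus the spread of the predictions. Choosing lambda from a grid such as lambda.seq is the job lavPredictY_cv() takes on.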

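The script below compares regularized out-of-sample SEM prediction (RO-SEM), unregularized out-of-sample SEM prediction (OOS), and multivariate linear regression (MLR) using 100 repetitions of 10-fold cross-validation. It can be used to verify that the regularized method can increase predictive accuracy. Let me know if you need another empirical example. Thank you!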

library(lavaan)
library(semPlot)

data(PoliticalDemocracy)
# Rename so that the 1960 democracy indicators are z1-z4 and the 1965 indicators are y1-y4
colnames(PoliticalDemocracy) = c("z1", "z2", "z3", "z4", "y1", "y2", "y3", "y4", "x1", "x2", "x3")
head(PoliticalDemocracy)

model0 <- '
  # latent variable definitions
  ind60 =~ x1 + x2 + x3
  dem60 =~ z1 + z2 + z3 + z4
  dem65 =~ y1 + y2 + y3 + y4
  # regressions
  dem60 ~ ind60
  dem65 ~ ind60 + dem60
  # residual correlations
  z1 ~~ y1
  z2 ~~ z4 + y2
  z3 ~~ y3
  z4 ~~ y4
  y2 ~~ y4
'

fit <- sem(model0, data = PoliticalDemocracy, meanstructure = TRUE, warn = FALSE)
semPaths(fit, title = FALSE, intercepts = FALSE, residuals = FALSE)

# Repeated 10-fold cross-validation comparing three prediction approaches

model <- model0  # same specification as model0 above

# Predictors: z1-z4 and x1-x3; outcomes to predict: the 1965 indicators y1-y4
xnames = colnames(PoliticalDemocracy)[-c(5,6,7,8)]
ynames = colnames(PoliticalDemocracy)[c(5,6,7,8)]

set.seed(1234)
repeats = 100
PE = data.frame(repetition = rep(1:repeats, each = 3), 
                model = rep(1:3, repeats), 
                pe = rep(0, 3 * repeats))

folds = rep(1:10, length.out = 75)

for (r in 1:repeats){
  yhat1 = yhat2 = yhat3 = matrix(NA, 75, 4)  # predictions for RO-SEM, OOS, and MLR
  folds = sample(folds)

  print(paste("Iteration:", r))

  for(k in 1:10){
    idx = which(folds == k)

    # Fit SEM model
    fit <- sem(model, data = PoliticalDemocracy[-idx, ], meanstructure = TRUE, warn = FALSE)

    # RO-SEM approach: choose lambda by 10-fold CV within the training data
    reg.results <- lavPredictY_cv(
      fit,
      PoliticalDemocracy[-idx, ],
      xnames,
      ynames,
      n.folds = 10,
      lambda.seq = seq(from = .6, to = 2.5, by = .1)
    )
    lambda <- reg.results$lambda.min
    print(paste("lambda.min: ",lambda))
    yhat1[idx, ] = lavPredictY(fit, newdata = PoliticalDemocracy[idx, ], xnames = xnames, ynames = ynames, lambda = lambda)

    # OOS approach: unregularized prediction (default lambda)
    yhat2[idx, ] = lavPredictY(fit, newdata = PoliticalDemocracy[idx, ], xnames = xnames, ynames = ynames)

    # MLR: multivariate linear regression of y1-y4 on all other variables
    fit.lm = lm(cbind(y1,y2,y3,y4) ~ ., data = PoliticalDemocracy[-idx, ])
    yhat3[idx, ] = predict(fit.lm, newdata = PoliticalDemocracy[idx, ])

  }# end folds

  # RMSEp pooled over 75 cases x 4 outcomes = 300 predictions
  pe1 = sqrt(sum((PoliticalDemocracy[, ynames] - yhat1)^2)/300)
  pe2 = sqrt(sum((PoliticalDemocracy[, ynames] - yhat2)^2)/300)
  pe3 = sqrt(sum((PoliticalDemocracy[, ynames] - yhat3)^2)/300)
  PE$pe[((r-1)*3 + 1): (r*3)] = c(pe1, pe2, pe3)
} # end repetitions

library(ggplot2)
PE$model = as.factor(PE$model)
# saveRDS(PE, file = "outputs/political-dem-xval.rds")
# PE <- readRDS(file = "outputs/political-dem-xval.rds")

p <- ggplot(PE[PE$model == 1 | PE$model == 2 | PE$model == 3,], aes(x=model, y=pe, fill=factor(model))) +
      geom_boxplot(aes(group = factor(model))) + 
      geom_jitter(width = 0.05, height = 0, colour = rgb(0,0,0,.3)) + 
      xlab("Approach") + ylab("RMSEp") + 
      scale_x_discrete(labels=c("RO-SEM", "OOS", "MLR")) +
      theme(legend.position="none") +
      scale_fill_grey(start=.3,end=.7)

p

# Welch t-test comparing the RMSEp of RO-SEM (model 1) and OOS (model 2)
t.test(pe ~ model, data = PE[PE$model == 2 | PE$model == 1,])
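
The pe1 to pe3 lines compute the root mean squared error of prediction (RMSEp), pooled over all n = 75 cases and q = 4 outcomes. With the out-of-fold prediction of outcome j for case i written as a hatted y, this is why the script divides by nq = 300:

```latex
\mathrm{RMSEp} = \sqrt{\frac{1}{nq}\sum_{i=1}^{n}\sum_{j=1}^{q}\left(y_{ij} - \hat{y}_{ij}\right)^{2}},
\qquad n = 75,\; q = 4,\; nq = 300
```
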
yrosseel commented 4 months ago

Thanks for the PR. The checks revealed a few issues. There were some mismatches between the code and the documentation, and also the example in the documentation didn't run due to a missing comma. Can you please fix these and resubmit the PR?

mikedmolina commented 4 months ago

I was able to execute the checks workflow manually, and made another commit that should resolve all the issues. Everything should pass now. Let me know if you want me to squash down to a single commit before I resubmit the PR. Thanks, Yves!

yrosseel commented 4 months ago

Thanks. Merged now.

yrosseel commented 4 months ago

Looking at the code more closely: the first argument of cvLavPredictY() is 'model', and a new model is fitted explicitly for each fold (fold.fit), but the default values of xnames and ynames assume that a fitted 'object' is available. That will not work unless xnames and ynames are passed explicitly.

I think the first argument really should be 'object' (just like in the lavPredictY() function), and to create fold.fit, we can use update(). I will change the code accordingly.
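
The suggested pattern (take the fitted object as the first argument and refit it on each training fold with update()) can be sketched in base R. For illustration only, lm() stands in for a lavaan fit, and the helper cv_rmse() is hypothetical; it is not the code of lavPredictY_cv(). The point is that everything the refit needs, including the outcome name, is read from the fitted object itself.

```r
# Hypothetical helper (illustration only): k-fold cross-validated RMSE of a
# fitted model. The fitted object is the first argument; each training fold is
# refitted with update(), so the specification and variable names come from it.
cv_rmse <- function(object, data, n.folds = 5L) {
  yname <- all.vars(formula(object))[1]                  # response variable
  folds <- sample(rep_len(seq_len(n.folds), nrow(data)))
  sq.err <- numeric(nrow(data))
  for (k in seq_len(n.folds)) {
    test <- folds == k
    # Refit the same model on the training part only
    fold.fit <- update(object, data = data[!test, , drop = FALSE])
    pred <- predict(fold.fit, newdata = data[test, , drop = FALSE])
    sq.err[test] <- (data[[yname]][test] - pred)^2
  }
  sqrt(mean(sq.err))
}

# Usage
set.seed(2)
dat <- data.frame(x = rnorm(40))
dat$y <- 2 * dat$x + rnorm(40)
fit <- lm(y ~ x, data = dat)
rmse <- cv_rmse(fit, dat, n.folds = 5)
```

Because update() re-evaluates the original call with new data, the helper never needs the model syntax as a separate argument, which is the problem with passing 'model' described above.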

mikedmolina commented 4 months ago

Good catch. I'll do an additional code review this weekend to look for any discrepancies like this.

yrosseel commented 4 months ago

Ok. I have made the changes, and also streamlined the code somewhat. I also slightly altered the example in the man page. Can you please double-check? There is one thing I would like to change: the convention is that all (public) functions in lavaan start with 'lav'. Therefore, I would prefer to call this function lavCvPredictY() or something similar. But perhaps you have a better suggestion? (as long as it starts with lav)

mikedmolina commented 4 months ago

Everything looks great! With that naming convention in mind, I recommend lavPredictY.cv, since when someone searches for lavPredictY in the help documentation, autocomplete will show both lavPredictY and lavPredictY.cv. It is also somewhat more consistent with the naming conventions of other packages that use cross-validation for regularization, such as glmnet.

I've made all the renaming changes in my fork, in case you want to pull them!

Once we get everything updated, I'll update my code example in the comments above for posterity so that everything runs with the new function name.

mikedmolina commented 4 months ago

@yrosseel We also received our DOI this morning, so I've added that to my fork!

yrosseel commented 4 months ago

Could 'lavPredictY_cv' be an option? Or 'lavPredictYcv'? It is important to avoid function names containing a dot, because names of the form generic.class are what the S3 dispatch system uses to find methods.
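
To illustrate the S3 concern, the base-R sketch below uses a hypothetical generic predictY() (not a lavaan function). Once predictY is an S3 generic, any function named predictY.cv is silently picked up as its method for objects of class "cv". Assuming lavPredictY() is an ordinary (non-generic) function, a name like lavPredictY.cv would not dispatch today, but it would read as exactly such a method.

```r
# Hypothetical generic, used only to illustrate S3 method lookup by name
predictY <- function(object, ...) UseMethod("predictY")

predictY.default <- function(object, ...) "default method"

# A function whose name happens to have the form <generic>.<class> ...
predictY.cv <- function(object, ...) "predictY.cv was dispatched"

# ... is treated as the method for objects of class "cv"
predictY(1)                                 # returns "default method"
predictY(structure(list(), class = "cv"))   # returns "predictY.cv was dispatched"
```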

mikedmolina commented 4 months ago

_cv works! I'll update my fork accordingly.

mikedmolina commented 4 months ago

@yrosseel My fork has been updated, and the function is now named lavPredictY_cv!