OHDSI / DeepPatientLevelPrediction

An R package for performing patient level prediction using deep learning in an observational database in the OMOP Common Data Model.
https://ohdsi.github.io/DeepPatientLevelPrediction

Only save performance of hyperparameter combination in training cache #88

Closed: lhjohn closed this 11 months ago

lhjohn commented 11 months ago

In gridCVDeep, the gridSearchPrediction object grows to several GBs on larger data: https://github.com/OHDSI/DeepPatientLevelPrediction/blob/2aba7580f8410d623dfb6f012ce6fed95bdf417f/R/Estimator.R#L368 The training cache stores this object and saves it as an RDS file, which can take more than 5 minutes.

We should save only the performance, which is currently just the AUC.
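A minimal sketch of what "save only the performance" could mean for one grid entry: keep the hyperparameters and a single cross-validated score, and drop the per-patient prediction data frame. The helpers rankAuc() and slimCacheEntry() are hypothetical; rankAuc() is a Mann-Whitney rank AUC used as a stand-in for the AUC that PatientLevelPrediction computes, and the sketch assumes prediction columns named value, outcomeCount and index (the CV fold), with both classes present in every fold.

```r
# Hypothetical stand-in for the package's AUC: Mann-Whitney rank formulation
# (ties get average ranks). Assumes each fold has positives and negatives.
rankAuc <- function(value, outcome) {
  r <- rank(value)
  nPos <- sum(outcome == 1)
  nNeg <- sum(outcome == 0)
  (sum(r[outcome == 1]) - nPos * (nPos + 1) / 2) / (nPos * nNeg)
}

# Hypothetical cache entry: hyperparameters plus the mean AUC over the folds
# stored in prediction$index; the prediction itself is not kept.
slimCacheEntry <- function(prediction, param) {
  foldAuc <- vapply(
    split(prediction, prediction$index),
    function(fold) rankAuc(fold$value, fold$outcomeCount),
    numeric(1)
  )
  list(param = param, cvPerformance = mean(foldAuc))
}

# toy two-fold prediction
prediction <- data.frame(
  value = c(0.9, 0.8, 0.3, 0.2, 0.7, 0.1, 0.6, 0.4),
  outcomeCount = c(1, 1, 0, 0, 1, 0, 0, 1),
  index = c(1, 1, 1, 1, 2, 2, 2, 2)
)
entry <- slimCacheEntry(prediction, list(learningRate = 0.01))
entry$cvPerformance  # 0.875 (fold AUCs 1 and 0.75)
```

Such an entry is a few bytes regardless of cohort size, so serializing it with saveRDS stays cheap.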

lhjohn commented 11 months ago

This won't be as easy as I thought, since we also need to keep track of the best prediction: it is attached to the result as the CV prediction.
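One way to keep track of the best prediction without storing all of them is to record every entry's performance but replace a single stored prediction only when a new entry beats the current best. The sketch below is a hypothetical cache layout (newCache() and addGridResult() are illustrative names, not the package's cache format, which the conversion script in this thread handles by setting non-best predictions to list(NULL)).

```r
# Hypothetical cache: performance and params for every grid entry,
# but only the prediction of the best entry seen so far.
newCache <- function(nGrid) {
  list(
    performance = rep(NA_real_, nGrid),
    params = vector("list", nGrid),
    bestIndex = NA_integer_,
    bestPrediction = NULL
  )
}

addGridResult <- function(cache, gridId, param, prediction, cvPerformance) {
  cache$performance[gridId] <- cvPerformance
  cache$params[[gridId]] <- param
  # swap in the new prediction only if this entry beats the current best
  if (is.na(cache$bestIndex) ||
      cvPerformance > cache$performance[cache$bestIndex]) {
    cache$bestIndex <- gridId
    cache$bestPrediction <- prediction
  }
  cache
}

cache <- newCache(3)
cache <- addGridResult(cache, 1, list(lr = 0.1), data.frame(value = 0.2), 0.70)
cache <- addGridResult(cache, 2, list(lr = 0.01), data.frame(value = 0.5), 0.78)
cache <- addGridResult(cache, 3, list(lr = 0.001), data.frame(value = 0.4), 0.74)
cache$bestIndex  # 2
```

At most one prediction is held at any time, so the saved cache stays small while the best CV prediction remains available for the result.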

lhjohn commented 11 months ago

Here is a script that converts old caches to the new format. The conversion is done in place, so it may be useful to back up the old cache first.

# Converts an old training cache to the new format in place:
# the file at `path` is overwritten at the end.
path <- file.path("D:/backup/paramPersistence.rds")

paramPersistence <- readRDS(path)

# add the grid performance (currently AUC) to every completed grid entry
for (gridId in seq_along(paramPersistence$gridSearchPredictions)) {
  entry <- paramPersistence$gridSearchPredictions[[gridId]]
  if (!is.null(entry)) {
    gridPerformance <- PatientLevelPrediction::computeGridPerformance(
      entry$prediction,
      entry$param
    )

    paramPersistence$gridSearchPredictions[[gridId]] <- list(
      prediction = entry$prediction,
      param = entry$param,
      gridPerformance = gridPerformance
    )
    print(paste0("Add grid performance for index: ", gridId))
  }
}

# remove all predictions that are not the max performance;
# entries that were never run map to NA so indices stay aligned
# (unlist() would silently drop them and shift the index of the max)
cvPerformance <- vapply(
  paramPersistence$gridSearchPredictions,
  function(x) if (is.null(x)) NA_real_ else x$gridPerformance$cvPerformance,
  numeric(1)
)
indexOfMax <- which.max(cvPerformance)
for (i in seq_along(paramPersistence$gridSearchPredictions)) {
  if (!is.null(paramPersistence$gridSearchPredictions[[i]]) && i != indexOfMax) {
    paramPersistence$gridSearchPredictions[[i]]$prediction <- list(NULL)
    print(paste0("Remove prediction for index: ", i))
  }
}

saveRDS(paramPersistence, path)
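Before relying on a converted cache, a quick check can confirm that every completed entry has a gridPerformance and that exactly one entry still carries a real prediction. checkConvertedCache() is a hypothetical helper, not part of either package; it assumes the layout produced by the conversion script, where dropped predictions are replaced by list(NULL).

```r
# Hypothetical helper: TRUE if every completed entry has a gridPerformance
# and exactly one entry still holds a prediction (the best one).
checkConvertedCache <- function(paramPersistence) {
  entries <- Filter(Negate(is.null), paramPersistence$gridSearchPredictions)
  hasPerformance <- vapply(
    entries, function(x) !is.null(x$gridPerformance), logical(1)
  )
  # the conversion script replaces non-best predictions with list(NULL)
  keepsPrediction <- vapply(
    entries, function(x) !identical(x$prediction, list(NULL)), logical(1)
  )
  all(hasPerformance) && sum(keepsPrediction) == 1
}

# toy cache: one dropped prediction, one kept, one grid entry never run
toy <- list(gridSearchPredictions = list(
  list(prediction = list(NULL), param = list(a = 1),
       gridPerformance = list(cvPerformance = 0.7)),
  list(prediction = data.frame(value = 0.3), param = list(a = 2),
       gridPerformance = list(cvPerformance = 0.8)),
  NULL
))
checkConvertedCache(toy)
```

On a real cache, the same call would be checkConvertedCache(readRDS(path)).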