quentingronau / bridgesampling

R package for bridge sampling
32 stars 16 forks source link

Thought about memory usage #37

Open mawilson1234 opened 1 year ago

mawilson1234 commented 1 year ago

Thank you for making this package! Apologies if this is out of place, but I had a thought about the memory usage of `bridge_sampler()` when running multiple repetitions on large models. I was trying to run 100 repetitions for a fairly large model on a cluster and ran out of memory even with 512 GB of RAM available.

Of course, it would be possible to run multiple calls with fewer repetitions and try to put them together myself, but I was looking at the source code to try and figure out what was causing the large memory usage, and had a thought about how it could be refactored a little bit in a way that might reduce RAM usage.

Essentially, I noticed that `gen_samples` and `q22` are currently pre-computed and stored for all repetitions up front. They are then used in separate loops, first to compute `q21` and then to run the iterative scheme. I was wondering if it might work to instead generate `gen_samples`, `q22`, and `q21` for one repetition at a time in a single loop, without storing the values for every repetition. This should allow the garbage collector to reclaim the memory used in the previous iteration. To still print summaries to the console when `verbose == TRUE`, the print statements need to be reordered a little, but otherwise it doesn't look like any results should change.

Here's a preliminary attempt at refactoring the code in this way for `.bridge.sampler.normal`, in case it's of interest. In brief, I moved the computation of `q11` into its own block before the loop and merged the bodies of the separate `for (i in seq_len(repetitions))` loops into a single loop. The summaries for `q21`, `q22`, and the iterations are now printed at the end of each iteration instead of all at once after everything has finished, but I think that's the only observable change. Because `gen_samples`, `q21`, and `q22` are created inside the loop, the user should only need enough RAM to hold one copy of each at a time, rather than one copy per repetition.

# Bridge sampling estimate of the log marginal likelihood using a
# multivariate normal proposal fitted to `samples_4_fit`.
#
# Memory-conscious variant: the proposal draws (`gen_samples`) and their
# log-densities (`q21`, `q22`) are generated one repetition at a time inside
# the main loop and overwritten on the next iteration, so only one set is
# held in memory regardless of `repetitions`.
#
# Returns an object of class "bridge" when repetitions == 1 (including the
# q11/q12/q21/q22 log-density vectors), or of class "bridge_list" when
# repetitions > 1 (one logml/niter entry per repetition).
.bridge.sampler.normal <- function(
  samples_4_fit, # matrix with already transformed samples for fitting the
                 # proposal (rows are samples), colnames are "trans_x" where
                 # x is the parameter name
  samples_4_iter, # matrix with already transformed samples for the
                  # iterative scheme (rows are samples), colnames are "trans_x"
                  # where x is the parameter name
  neff, # effective sample size of samples_4_iter (i.e., already transformed samples), scalar
  log_posterior,
  ...,
  data,
  lb, ub,
  transTypes, # types of transformations (unbounded/lower/upperbounded) for the different parameters (named character vector)
  param_types, # Sample space for transformations (real, circular, simplex)
  cores,
  repetitions,
  packages,
  varlist,
  envir,
  rcppFile,
  maxiter,
  silent,
  verbose,
  r0,
  tol1,
  tol2) {

  if (is.null(neff))
    neff <- nrow(samples_4_iter)

  n_post <- nrow(samples_4_iter)

  # get mean & covariance matrix of the normal proposal
  m <- apply(samples_4_fit, 2, mean)
  V_tmp <- cov(samples_4_fit)
  V <- as.matrix(nearPD(V_tmp)$mat) # make sure that V is positive-definite

  # Decide once how log_posterior is evaluated: serially, with forked workers
  # (unix), or with a PSOCK cluster (other platforms). The cluster is created
  # a single time here and reused for the posterior draws and for the
  # proposal draws of every repetition.
  run_parallel <- cores > 1
  use_fork <- .Platform$OS.type == "unix"
  cl <- NULL
  if (run_parallel && !use_fork) {
    cl <- parallel::makeCluster(cores, useXDR = FALSE)
    on.exit(parallel::stopCluster(cl), add = TRUE)
    invisible(lapply(packages, function(x)
      parallel::clusterCall(cl = cl, "require", package = x,
                            character.only = TRUE)))
    parallel::clusterExport(cl = cl, varlist = varlist, envir = envir)
    if (!is.null(rcppFile)) {
      parallel::clusterExport(cl = cl, varlist = "rcppFile", envir = parent.frame())
      parallel::clusterCall(cl = cl, "require", package = "Rcpp", character.only = TRUE)
      parallel::clusterEvalQ(cl = cl, Rcpp::sourceCpp(file = rcppFile))
    } else if (is.character(log_posterior)) {
      parallel::clusterExport(cl = cl, varlist = log_posterior, envir = envir)
    }
  }

  # Evaluates log_posterior (on the back-transformed, real-scale samples)
  # plus the log-Jacobian of the transformation for every row of a matrix of
  # transformed samples. Sharing this closure between q11 and q21 keeps the
  # serial and parallel code paths identical. `...` resolves lexically to
  # the `...` of .bridge.sampler.normal.
  eval_log_posterior <- function(trans_samples) {
    real_samples <- .invTransform2Real(trans_samples, lb, ub, param_types)
    if (!run_parallel) {
      log_dens <- apply(real_samples, 1, log_posterior, data = data, ...)
    } else if (use_fork) {
      chunks <- .split_matrix(matrix = real_samples, cores = cores)
      log_dens <- unlist(parallel::mclapply(
        chunks,
        FUN = function(x) apply(x, 1, log_posterior, data = data, ...),
        mc.preschedule = FALSE,
        mc.cores = cores
      ))
    } else {
      log_dens <- parallel::parRapply(cl = cl, x = real_samples, log_posterior,
                                      data = data, ...)
    }
    log_dens + .logJacobian(trans_samples, transTypes, lb, ub)
  }

  # log_dens of the posterior for the posterior samples; checked regardless
  # of `cores` (previously these checks only ran in the parallel branch)
  q11 <- eval_log_posterior(samples_4_iter)
  if (any(is.infinite(q11))) {
    warning(sum(is.infinite(q11)), " of the ", length(q11)," log_prob() evaluations on the posterior draws produced -Inf/Inf.", call. = FALSE)
  }
  if (any(is.na(q11))) {
    warning(sum(is.na(q11)), " evaluation(s) of log_prob() on the posterior draws produced NA and have been replaced by -Inf.", call. = FALSE)
    q11[is.na(q11)] <- -Inf
  }

  # log_dens of the proposal for the posterior samples
  q12 <- dmvnorm(samples_4_iter, mean = m, sigma = V, log = TRUE)
  if (verbose) {
    print("summary(q11): (log_dens of posterior (i.e., with log_posterior) for posterior samples)")
    print(summary(q11))
    print("summary(q12): (log_dens of proposal (i.e., with dmvnorm) for posterior samples)")
    print(summary(q12))
  }

  logml <- numeric(repetitions)
  niter <- numeric(repetitions)
  q21 <- NULL
  q22 <- NULL
  for (i in seq_len(repetitions)) {
    # Draw and evaluate this repetition's proposal samples only; they are
    # overwritten on the next iteration so memory does not scale with
    # `repetitions`.
    gen_samples <- rmvnorm(n_post, mean = m, sigma = V)
    colnames(gen_samples) <- colnames(samples_4_iter)
    q22 <- dmvnorm(gen_samples, mean = m, sigma = V, log = TRUE)
    q21 <- eval_log_posterior(gen_samples)

    if (any(is.infinite(q21))) {
      warning(sum(is.infinite(q21)), " of the ", length(q21)," log_prob() evaluations on the proposal draws produced -Inf/Inf.", call. = FALSE)
    }
    if (all(is.na(q21))) {
      stop("Evaluations of log_prob() on all proposal draws produced NA.\n",
           "E.g., rounded to 3 digits (use verbose = TRUE for all proposal samples):\n",
           deparse(round(
             .invTransform2Real(gen_samples, lb, ub, param_types)[1,],
             3), width.cutoff = 500L),
           call. = FALSE)
    }
    if (any(is.na(q21))) {
      warning(sum(is.na(q21)), " evaluation(s) of log_prob() on the proposal draws produced NA and have been replaced by -Inf.", call. = FALSE)
      q21[is.na(q21)] <- -Inf
    }

    # run iterative updating scheme to compute log of marginal likelihood
    tmp <- .run.iterative.scheme(q11 = q11, q12 = q12, q21 = q21, q22 = q22,
                                 r0 = r0, tol = tol1, L = NULL, method = "normal",
                                 maxiter = maxiter, silent = silent,
                                 criterion = "r", neff = neff)

    if (is.na(tmp$logml) && !is.null(tmp$r_vals)) {
      warning("logml could not be estimated within maxiter, rerunning with adjusted starting value. \nEstimate might be more variable than usual.", call. = FALSE)
      lr <- length(tmp$r_vals)
      # use geometric mean as starting value
      r0_2 <- sqrt(tmp$r_vals[[lr - 1]] * tmp$r_vals[[lr]])
      tmp <- .run.iterative.scheme(q11 = q11, q12 = q12, q21 = q21, q22 = q22,
                                   r0 = r0_2, tol = tol2, L = NULL, method = "normal",
                                   maxiter = maxiter, silent = silent,
                                   criterion = "logml", neff = neff)
      tmp$niter <- maxiter + tmp$niter
    }
    logml[i] <- tmp$logml
    niter[i] <- tmp$niter
    if (niter[i] == maxiter)
      warning("logml could not be estimated within maxiter, returning NA.", call. = FALSE)

    if (verbose) {
      print("summary(q22): (log_dens of proposal (i.e., with dmvnorm) for generated samples)")
      print(summary(q22))
      print("summary(q21): (log_dens of posterior (i.e., with log_posterior) for generated samples)")
      print(summary(q21))
    }
  }

  if (repetitions == 1) {
    # q21/q22 are plain vectors from the single repetition (not lists), so
    # they are returned whole rather than indexed with [[1]]
    out <- list(logml = logml, niter = niter, method = "normal", q11 = q11,
                q12 = q12, q21 = q21, q22 = q22)
    class(out) <- "bridge"
  } else if (repetitions > 1) {
    out <- list(logml = logml, niter = niter, method = "normal", repetitions = repetitions)
    class(out) <- "bridge_list"
  }

  out

}

Would this change make sense? I could very well be missing something important about why things have to be the way they currently are, but reducing RAM usage could make it more feasible to run multiple repetitions, which would be useful for getting better estimates.