venpopov / chkptstanr

Checkpoint Stan R
https://venpopov.github.io/chkptstanr/

What is the point of the extra "typical" initial warmup that is not done without chkptstanr? #10

Open venpopov opened 6 months ago

venpopov commented 6 months ago

This issue is about how checkpointing works with stan's adaptation process.

In chkptstanr v0.1.1, warmup has two stages:

I don't understand why we need a separate initial warmup; the same question was also asked here

Initial thoughts

To better understand what is going on, I ran this in debugging mode:

path <- setup_model_testing(dir = "context4")
formula <- brms::bf(formula = count ~ zAge + zBase)
family <- stats::poisson()
fit <- chkpt_brms(
  formula = formula,
  family = family,
  data = brms::epilepsy,
  iter_warmup = 400,
  iter_sampling = 1200,
  iter_per_chkpt = 200,
  path = path
)

After investigating the code a bit more, I found that it does the following:

Run an initial warmup with iter_typical iterations (150 by default)

This runs the model for 150 warmup iterations and returns the inverse metric for each chain, followed by the step size for each chain:

$`1`
[1] 0.000324329 0.000301825 0.000459483
$`2`
[1] 0.00072603 0.00029929 0.00067115
[1] 0.627025 0.564246

Run checkpointing

If we set iter_per_chkpt to 200, as in the above example, and iter_warmup = 400, there will be two warmup checkpoints. For the first, it runs cmdstanr with the options:

Since adapt_engaged is FALSE, this does not change the inv_metric or step_size; the values at the end of this checkpoint are the same:

$inv_metric
$inv_metric$`1`
[1] 0.000324329 0.000301825 0.000459483
$inv_metric$`2`
[1] 0.00072603 0.00029929 0.00067115
$step_size_adapt
[1] 0.627025 0.564246

which are then passed to the next "warmup" checkpoint.

Analysis

Given this, I now understand what the "typical warmup set" is: it is the only actual warmup. No adaptation is done during the subsequent warmup checkpoints, which are therefore equivalent to discarded sampling iterations.

This has a major consequence:

venpopov commented 6 months ago

@jsocolar, perhaps your input will be useful here, as you offered. After fixing all the major bugs, I only now started looking into how the package deals with adaptation and checkpointing. Here's what I found:

tldr:

jsocolar commented 6 months ago

I wondered about that! I thought that's what I was seeing, but didn't get to the bottom of it.

I think that the cleanest solution might be to reimplement the adaptation schedule in the checkpointing logic. This will be somewhat challenging, but I think it is totally feasible. The thorniest part will be handling everything properly when one of the windowed adaptation phases needs to span multiple checkpoints (typically, the longest adaptation window is roughly half of the total time spent in adaptation, so it will need to get broken up). But this is surmountable.

venpopov commented 6 months ago

That's exactly what I was just thinking after rereading the info about how the warmup period is split. I think it should be totally doable. I'll give it a go and see how it goes.

jsocolar commented 6 months ago

Ping me with questions if they arise. The most delicate parts will be (1) manually computing the correct regularized mass matrix from the draws when an adaptation window is split across multiple checkpoints and (2) correctly boosting the step-size after the mass matrix updates.
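For point (1), here is a minimal R sketch of the regularization as I understand it from Stan's var_adaptation and covar_adaptation code: the window's sample (co)variance, computed with the n - 1 denominator, is shrunk toward 1e-3 times the identity with weight 5 / (n + 5). The name compute_metric is hypothetical (it matches the name proposed later in this thread), and the sketch assumes `draws` already holds unconstrained draws from a single adaptation window.

```r
# Hypothetical helper (sketch): regularized inverse metric from ONE adaptation
# window. `draws` is an iterations x parameters matrix of unconstrained draws.
compute_metric <- function(draws, dense = FALSE) {
  draws <- as.matrix(draws)
  n <- nrow(draws)
  stopifnot(n >= 2)
  weight <- n / (n + 5)            # weight on the sample estimate
  shrink <- 1e-3 * (5 / (n + 5))   # shrinkage toward 1e-3 * identity
  if (dense) {
    # stats::cov() uses the n - 1 denominator, like Stan's Welford estimator
    weight * stats::cov(draws) + shrink * diag(ncol(draws))
  } else {
    weight * apply(draws, 2, stats::var) + shrink
  }
}

# Usage with simulated "unconstrained draws" for 3 parameters (100 iterations)
set.seed(1)
draws <- matrix(rnorm(300, sd = c(0.02, 0.02, 0.03)), ncol = 3, byrow = TRUE)
inv_metric_diag  <- compute_metric(draws)
inv_metric_dense <- compute_metric(draws, dense = TRUE)
```

The result has the shape cmdstanr's $sample() expects for inv_metric: a vector for metric = "diag_e" and a matrix for metric = "dense_e".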

Edit: the latter is easy but it's not well documented. I believe the current hard-coded behavior is to boost the previous step-size by a factor of ten. My pretty strong impression is that this isn't an ideal behavior, but it's what Stan does and it's not going to change unless somebody can conclusively show that a different strategy performs better.

venpopov commented 6 months ago

Will do. Thanks for your input! Updating the package ended up being a much bigger task than I expected, but I'm really keen on using this functionality, so I will keep at it.

jsocolar commented 6 months ago

Here's the canonical thread that explains the step-size behavior across adaptation windows, for reference... https://discourse.mc-stan.org/t/issue-with-dual-averaging/5995

venpopov commented 6 months ago

And as a note to self, until this is fixed:

jsocolar commented 6 months ago

I think there's an unforeseen problem lurking here that will make it very tricky to split an adaptation window across multiple checkpoints. I think the current plan is to reinitialize a new fitting routine with the inits, mass matrix, and step-size corresponding to the end of the previous run. This almost works, but doesn't fully recapitulate the behavior from Stan. The problem is that initializing the step size isn't sufficient to recreate the state of the dual averaging algorithm. I think doing that would require initializing all of the step size, the primal and dual (x-bar and s-bar I think in the Stan source code), and the iteration number.

We cannot do that, because these parameters are not exposed in Stan. I think this leaves the following options:

  1. Never allow a checkpoint to split an adaptation window (this would make things easy for us, but note that the longest adaptation window is generally approximately a quarter of the total fitting time)
  2. When a checkpoint splits an adaptation window, add an extra burn-in after the split to allow the step size to equilibrate
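To illustrate why restarting from the step size alone is not the same as continuing, here is a hedged R sketch of dual averaging of log(step size) as I read Stan's stepsize_adaptation code, with Stan's default constants (delta = 0.8, gamma = 0.05, kappa = 0.75, t0 = 10). The names da_start and da_update are hypothetical, and the sketch omits Stan's init_stepsize heuristic. A restart throws away the counter, s_bar and x_bar, and re-centers mu at log(10 * step size), which is where the 10x boost mentioned above comes from.

```r
# Hypothetical sketch of Stan-style dual averaging of log(step size).
da_start <- function(step_size, delta = 0.8, gamma = 0.05, kappa = 0.75, t0 = 10) {
  list(step_size = step_size, mu = log(10 * step_size),
       counter = 0, s_bar = 0, x_bar = 0,
       delta = delta, gamma = gamma, kappa = kappa, t0 = t0)
}

da_update <- function(s, accept_stat) {
  accept_stat <- min(accept_stat, 1)
  s$counter <- s$counter + 1
  eta <- 1 / (s$counter + s$t0)
  # dual variable: running average of (target - observed) acceptance
  s$s_bar <- (1 - eta) * s$s_bar + eta * (s$delta - accept_stat)
  # primal: log step size used for the next iteration
  x <- s$mu - s$s_bar * sqrt(s$counter) / s$gamma
  # averaged primal (Stan uses exp(x_bar) once adaptation ends)
  x_eta <- s$counter^(-s$kappa)
  s$x_bar <- (1 - x_eta) * s$x_bar + x_eta * x
  s$step_size <- exp(x)
  s
}

accept <- c(0.9, 0.6, 0.85, 0.7)   # made-up acceptance statistics

# Uninterrupted: four updates carrying the full state along
continued <- Reduce(da_update, accept, da_start(0.5))

# "Checkpointed": three updates, then restart from the step size alone
first_part <- Reduce(da_update, accept[1:3], da_start(0.5))
restarted <- da_update(da_start(first_part$step_size), accept[4])
```

With these made-up inputs the restarted step size after the fourth iteration comes out about an order of magnitude larger than the uninterrupted one, consistent with the need for extra burn-in in option 2.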
venpopov commented 6 months ago

Thanks for thinking about this. I've been reading more to understand how the adaptation works in detail, and I was reaching a similar conclusion: checkpoints should not split an adaptation window. I'll think more about it before implementing anything. We would also need to specify init_buffer etc. manually. We can set init_buffer and term_buffer to 0 during the slow adaptation checkpoints, but I don't see a way to control the slow adaptation windows. It seems like Stan would always split a checkpoint into expanding intervals (about to test that). Any idea how to control that?

jsocolar commented 6 months ago

You can control the length of the first window with the window argument. So for example if you've just finished a 25-iteration window and want to do 150 iterations, with a 50-iteration window and a 100-iteration window, I think you would pass zero for init_buffer and term_buffer, and 50 for window.

And you'd manually boost the step-size 10x from its last value.

jsocolar commented 6 months ago

If the total windowed phase requested is not expressible as $\sum_{i=0}^{k-1} 2^i w = (2^k - 1)\,w$ for some integer $k$, where $w$ is the window argument, then I'm pretty sure that window is respected, and the final window is made extra-long to accommodate any and all "leftover" iterations belonging to the windowed phase.
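Spelling out that rule with the same formula that get_adaptation_schedule uses further down: for a windowed phase of $W$ iterations and base window $w$, the number of windows is $k = \lfloor \log_2(W/w + 1) \rfloor$, and any leftover iterations extend the last window. Three worked cases (my own arithmetic, assuming the doubling rule): the 150-iteration example above, a total that does not fit exactly, and Stan's defaults for 1000 warmup iterations, where $W = 1000 - 75 - 50 = 875$ and the last window is half of all warmup iterations.

$$
\begin{aligned}
W = 150,\ w = 50:&\quad k = \lfloor \log_2(150/50 + 1) \rfloor = 2,\quad \text{windows } 50,\ 100\\
W = 180,\ w = 50:&\quad k = \lfloor \log_2(180/50 + 1) \rfloor = 2,\quad \text{windows } 50,\ 100 + 30 = 130\\
W = 875,\ w = 25:&\quad k = \lfloor \log_2(875/25 + 1) \rfloor = 5,\quad \text{windows } 25,\ 50,\ 100,\ 200,\ 400 + 100 = 500
\end{aligned}
$$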

jsocolar commented 6 months ago

I just thought of a different flavor of design that might be worth considering (disadvantage: it wouldn't work for rstan, just cmdstan). If we just run the sampler with save_warmup = TRUE, then rather than stopping and restarting the sampler for checkpointing, we could just write a script that can look at the csv (or csvs plural if the sampler has already stopped and started at least once), superimpose the known warmup schedule, find the last finished window, compute the regularized inverse metric from that window, and re-start sampling there. That is, the stan CSVs already provide all the information we need to restart the sampler from the last completed window (or from the last completed iteration if we're in the sampling phase).
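A hedged sketch of the "superimpose the known warmup schedule" step: given a schedule in the format returned by get_adaptation_schedule (posted further down; columns phase, length, iter) and the number of iterations already in the CSVs, find where sampling could resume. The helper find_restart_point is hypothetical, and the rule it encodes (warmup resumes only at a phase boundary, sampling after any completed iteration) is my reading of the proposal above, not an existing chkptstanr or cmdstanr function.

```r
# Hypothetical helper (sketch): where can sampling resume, given a schedule with
# columns phase, length, iter and the number of iterations already saved?
find_restart_point <- function(schedule, n_done) {
  phase_end <- schedule$iter + schedule$length - 1
  in_sampling <- schedule$phase == "sampling"
  # warmup phases (buffers and windows) are only resumed at a phase boundary
  done <- !in_sampling & phase_end <= n_done
  restart_iter <- if (any(done)) max(phase_end[done]) + 1 else 1
  # during sampling, any completed iteration is a valid restart point
  if (any(in_sampling) && n_done >= min(schedule$iter[in_sampling])) {
    restart_iter <- n_done + 1
  }
  list(
    restart_iter = restart_iter,
    phase = schedule$phase[max(which(schedule$iter <= restart_iter))]
  )
}

# Stan's default schedule for 1000 warmup + 1000 sampling iterations
schedule <- data.frame(
  phase  = c("init_buffer", rep("window", 5), "term_buffer", "sampling"),
  length = c(75, 25, 50, 100, 200, 500, 50, 1000)
)
schedule$iter <- 1 + c(0, head(cumsum(schedule$length), -1))

# Interrupted at iteration 300, i.e. part-way through the 200-iteration window
restart <- find_restart_point(schedule, n_done = 300)
```

Here the last finished window covers iterations 151 to 250, so its draws would be the ones used for the regularized inverse metric, and the partially completed window would simply be rerun from iteration 251.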

venpopov commented 6 months ago

You know, I was considering the same but for a different reason: speed. Right now, the way checkpointing was implemented in the original package, there is considerable overhead in completing and then starting a checkpoint. I was thinking of doing what you suggested just for efficiency, but you make a great point that it might solve the adaptation issue as well.

As for rstan, the package did not work with rstan before either :) (although it was possible to code it in a similar way). But I'm not too worried about that. I used rstan exclusively until about 2 months ago, and now I don't want to go back. I get far fewer crashes, and for some of my models much faster sampling (although I have not been able to figure out why or when).

jsocolar commented 6 months ago

Here's a function to return the adaptation schedule. If you'd like I can submit a PR on the feature/refactor branch?

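#' Adaptation schedule implied by Stan's windowed adaptation settings
#'
#' @param iter_warmup number of requested warmup iterations
#' @param iter_sampling number of requested sampling iterations
#' @param init_buffer length of the initial fast adaptation buffer (Stan default: 75)
#' @param term_buffer length of the terminal fast adaptation buffer (Stan default: 50)
#' @param window length of the first slow adaptation window (Stan default: 25);
#'   each subsequent window doubles in length
#' @return a data.frame with one row per phase and columns `phase`, `length`,
#'   and `iter` (the first iteration of the phase)
#' @examples
#' get_adaptation_schedule(1000, 1000, 75, 50, 25)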

get_adaptation_schedule <- function(
    iter_warmup, 
    iter_sampling, 
    init_buffer, 
    term_buffer, 
    window) {
  # Make sure that we can fit the requested buffers
  assertthat::assert_that(iter_warmup >= init_buffer + term_buffer)

  # determine how many windows we have space for
  window_iters <- iter_warmup - (init_buffer + term_buffer)
  n_window <- floor(log(window_iters/window + 1, 2))

  # if n_window is zero, add any leftover iterations to the init buffer
  if(n_window == 0) {
    init_buffer <- init_buffer + window_iters
    window_iters <- 0
  }

  # initialize output
  out <- data.frame(phase = character(), length = integer())

  # add the init buffer if requested
  if(init_buffer > 0){
    out <- rbind(out, data.frame(phase = "init_buffer", length = init_buffer))
  }

  # add the windowed phase if requested
  if(window_iters > 0){
    window_lengths <- window * 2^c(0:(n_window-1))
    window_lengths[n_window] <- window_lengths[n_window] + (window_iters - sum(window_lengths))
    out_window <- data.frame(
      phase = "window",
      length = window_lengths
    )
    out <- rbind(out, out_window)
  }

  # add the term buffer if requested
  if(term_buffer > 0) {
    out <- 
      rbind(
        out, 
        data.frame(
          phase = "term_buffer", 
          length = term_buffer
          )
        )
  }

  # add the sampling phase if requested
  if(iter_sampling > 0) {
    out <- 
      rbind(
        out, 
        data.frame(
          phase = "sampling", 
          length = iter_sampling
        )
      )
  }

  # put in starting iterations
  n_phase <- nrow(out)
  if(n_phase == 1){
    out$iter <- 1
  } else {
    out$iter <- 1 + c(0, cumsum(out$length[1:(n_phase - 1)]))
  }

  # return
  out
}
venpopov commented 6 months ago

Thanks! Sure, do a pull request there and I'll see how to integrate it with the ongoing changes I'm making. It is isolated, so that should be easy. I'll then push my commits once I have a working refactored version.

jsocolar commented 6 months ago

If you want to spell out some development tasks that are more specific to this refactor than the development roadmap, I might be able to pick one or two up in the next few days. FWIW, here's what I see as the critical tasks necessary to get to a prototype here:

  1. A function compute_metric to take a set of unconstrained draws and compute the regularized inverse metric, for both diagonal and dense cases.
  2. A function unconstrained_parameters to take a model specification and model draws, and return the unconstrained draws for parameters (but not transformed parameters or generated quantities) in the order of the columns of the mass matrix.
  3. A function combined_csvs to take multiple CSVs (resulting from one or more restarts), read them in, and combine them into a single draws object.
  4. A function that uses the above functions, in conjunction with get_adaptation_schedule, to extract the correct arguments for iter_warmup, iter_sampling, init_buffer, window, term_buffer, sampling, init, step_size, inv_metric when restarting sampling.
  5. A function that actually restarts sampling.

I would be very happy to work on 1 and/or 4, and happy enough to work on any of the others.
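For task 4, a minimal sketch of the schedule arithmetic only (no metric, step size, or inits): translate what remains of the schedule into values for cmdstanr's $sample() arguments iter_warmup, iter_sampling, init_buffer, term_buffer and window. The helper restart_sample_args is hypothetical. It assumes a restart at a phase boundary during warmup and relies on the point made earlier in the thread that passing the first remaining window as window reproduces the remaining doubling windows. It does not handle edge cases, such as iter_warmup coming out as 0 (adaptation would presumably need adapt_engaged = FALSE) or settings where, I believe, Stan falls back to its own default buffer split because init_buffer + window + term_buffer exceeds iter_warmup.

```r
# Hypothetical helper (sketch): remaining schedule -> cmdstanr $sample() args.
# `schedule` has columns phase, length, iter (first iteration of each phase).
restart_sample_args <- function(schedule, restart_iter) {
  phase_end <- schedule$iter + schedule$length - 1
  # iterations of each phase still to run from restart_iter onward
  remaining <- pmax(0, phase_end - pmax(schedule$iter, restart_iter) + 1)
  left <- function(p) sum(remaining[schedule$phase == p])
  windows_left <- remaining[schedule$phase == "window" & remaining > 0]
  list(
    iter_warmup   = left("init_buffer") + left("window") + left("term_buffer"),
    iter_sampling = left("sampling"),
    init_buffer   = left("init_buffer"),
    term_buffer   = left("term_buffer"),
    # later windows double from this one; the last absorbs any leftover
    window        = if (length(windows_left) > 0) windows_left[1] else 0
  )
}

# Schedule for iter_warmup = 400, iter_sampling = 1200 with Stan's default
# buffers (75, 50) and base window 25: windows of 25, 50 and 200 iterations
schedule <- data.frame(
  phase  = c("init_buffer", "window", "window", "window", "term_buffer", "sampling"),
  length = c(75, 25, 50, 200, 50, 1200)
)
schedule$iter <- 1 + c(0, head(cumsum(schedule$length), -1))

# Resume after the second window has finished (iteration 150)
args <- restart_sample_args(schedule, restart_iter = 151)
# e.g. model$sample(iter_warmup = args$iter_warmup, window = args$window, ...)
```

The 400/1200 lengths mirror the chkpt_brms example at the top of the thread; the restart here yields iter_warmup = 250 with a single 200-iteration window followed by the 50-iteration term buffer.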

venpopov commented 6 months ago

That's a good list. I think 1 is a good place to start. Thanks a lot for helping with this! I will open a new issue for tracking the refactoring, describe my current thinking of how everything should be organized and include these in a bigger list of things that need implementing for tracking.

3 already has a precursor that I built for v0.2.0 (https://github.com/venpopov/chkptstanr/blob/master/R/make_brmsfit.R), but it will likely need adjustment. For now, my suggestion is to work on 1, and I will try to get the design document ready to coordinate how everything will fit together. Given the complete rewrite, my plan is actually to start clean and not worry about backwards compatibility.

venpopov commented 5 months ago

So a quick update. The refactoring is basically done, but I ended up rewriting everything from scratch. I also have not been able to get Donald to transfer the CRAN maintenance status, despite his initial response that I could take over if I wanted. So I decided to start fresh with a new package in a separate repo. I have to add some more tests and documentation, but I'll share it in a few days. In that version the checkpointing is no longer explicit. Instead, as we discussed above, upon interruption the fit is reconstructed from the combined CSV draw files, and you can increase the number of iterations and continue sampling at any point, even if the fit completed successfully. It makes for a much more streamlined process. Right now I have made it work after warmup, but the framework is there to accommodate the new functions we discussed for the adaptation. Will share the link soon.