arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

Leave Future Out Cross Validation #2068

Open jessegrabowski opened 2 years ago

jessegrabowski commented 2 years ago

Hello,

I am interested in adapting the refitting wrappers to implement LFO-CV, as described by Bürkner, Gabry, and Vehtari (2020), with the goal of cleaning everything up, writing unit tests, and submitting a PR. I am following the R implementation the authors provide in the paper's GitHub repo.

I have a somewhat working prototype, which I put into a notebook here. The "exact" method, where the model is completely re-fit at every time step, matches the R output very closely, so I am feeling pretty good that I am roughly on the right track. The much more important approximate method, however, seems to have a problem. The Pareto k-hat values seem to be much too high, and as a result the model is re-fit too many times. The R implementation requires only a single re-fit for the data I present in the notebook, whereas my implementation requires more than a dozen.
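To make the exact/approximate distinction concrete, here is a minimal sketch of the refit logic as I understand it from the paper and its R code: walk forward through time, keep reusing the last fit with PSIS weights while the Pareto k-hat stays below a threshold `tau`, and refit only when it exceeds `tau`. Model fitting and PSIS are abstracted behind a hypothetical `khat_of(n_fit, n_target)` callback, and the default `tau=0.7` is an assumption for illustration, not a value taken from the notebook.

```python
from typing import Callable, List, Tuple

def psis_lfo_refit_points(
    n_obs: int,
    min_obs: int,
    khat_of: Callable[[int, int], float],
    tau: float = 0.7,
) -> Tuple[List[int], List[float]]:
    """Decide where one-step-ahead PSIS-LFO-CV has to refit the model.

    khat_of(n_fit, n_target) is a hypothetical callback returning the Pareto
    k-hat of the importance ratios that move a posterior fitted to the first
    n_fit observations towards one conditioned on the first n_target.
    Returns (refit points, k-hat values seen after the initial fit).
    """
    n_fit = min_obs            # initial fit on observations [0, min_obs)
    refits = [n_fit]           # observation min_obs is predicted exactly
    khats = []
    # Step n_target predicts observation n_target (0-based) from [0, n_target).
    for n_target in range(min_obs + 1, n_obs):
        k = khat_of(n_fit, n_target)
        khats.append(k)
        if k > tau:
            # Importance sampling deemed unreliable: refit on [0, n_target)
            # and compute the exact ELPD for this step.
            n_fit = n_target
            refits.append(n_target)
        # Otherwise keep the old fit and use PSIS weights for this step.
    return refits, khats

# Toy k-hat that grows by 0.2 with every observation added since the last fit.
refits, khats = psis_lfo_refit_points(10, 3, lambda n_fit, n_target: 0.2 * (n_target - n_fit))
print(refits)  # [3, 7]
```

In this toy run, a k-hat that is systematically too high shows up directly as extra entries in `refits`, which is the symptom described above.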

I am positive I have made a mistake in my computations, and I suspect it is in the psis_log_weights function:

import arviz as az
import numpy as np

def psis_log_weights(ll_predict):
    # TODO: Support user-supplied name on time dim

    # Equation 10
    log_ratio = ll_predict.sum('time')
    lr_stacked = log_ratio.stack(__sample__ = ['chain', 'draw'])

    # Method "identity" gives closest match to loo::psis output in R, is it right?
    reff = az.ess(1 / np.exp(log_ratio), method="identity", relative=True).x.values

    # TODO: loo::psis doesn't need lr_stacked to be negative? See:
    #  https://github.com/paul-buerkner/LFO-CV-paper/blob/master/sim_functions.R#L231
    log_weights, k_hat = az.psislw(-lr_stacked, reff)

    return log_weights, k_hat
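A side note on the `1 / np.exp(log_ratio)` line: when the summed log ratios are large in magnitude, exponentiating them directly can overflow to `inf` or divide by zero. Because the ESS of a chain does not change when every draw is multiplied by the same positive constant, one option (a sketch, not something tested in this thread) is to shift by the minimum log ratio first. The helper name `inverse_ratios_for_reff` is hypothetical.

```python
import numpy as np

def inverse_ratios_for_reff(log_ratio):
    """Values proportional to 1 / exp(log_ratio), computed without overflow.

    Subtracting the minimum log ratio makes every exponent <= 0, so the
    result lies in [0, 1] with the largest value exactly 1. The ESS of a
    chain does not change when all draws are multiplied by the same positive
    constant, so this can stand in for 1 / np.exp(log_ratio).
    """
    log_ratio = np.asarray(log_ratio, dtype=float)
    return np.exp(-(log_ratio - log_ratio.min()))

log_ratio = np.array([-800.0, 0.0, 800.0])
with np.errstate(over="ignore", divide="ignore"):
    naive = 1 / np.exp(log_ratio)          # the -800 entry becomes inf
print(np.isinf(naive).any())               # True
print(inverse_ratios_for_reff(log_ratio))  # [1. 0. 0.]
```

For an xarray input, `np.exp(-(log_ratio - log_ratio.min()))` can be applied to the DataArray directly, which keeps the chain and draw dimensions that `az.ess` needs.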

I am worried I'm not familiar enough with the required inputs to az.ess and az.psislw, or with how these differ from the corresponding functions in the R loo package, to see where I am going wrong. I am hoping the community here will spot something dumb I did right away.

Thanks for any help you can give!

ahartikainen commented 2 years ago

Hi, have you tried the good old 'print all steps' in R and Python?

jessegrabowski commented 2 years ago

Haha yeah, I stepped through them both in debug mode side-by-side and took notes. That's the only way I got this far.

I have a suspicion I'm being a knucklehead by trying to follow the R code too closely. The places I'm confused are:

  1. When I compute the reff, I use the un-stacked log ratio (because az.ess expects a chain dimension), while az.psislw expects chain and draw to be stacked into a single sample dimension. Does this create inconsistencies?
  2. It's not super clear to me what the "method" argument in az.ess is doing, and there doesn't seem to be an equivalent argument in e.g. loo::relative_eff. I went with "identity" just because it made the outputs match, but not for any principled reason.
  3. Where does this 1 / np.exp(log_ratio) come from? I just blindly copied it from R because it makes the log_weights match more closely, but az.loo does nothing of the sort when computing the relative ESS. Here's the original R code; in both cases logratio is a (chains * samples, 1) vector made by summing the out-of-sample log-likelihoods across the time dimension:
  r_eff_lr <- loo::relative_eff(logratio, chain_id = chain_id)
  r_eff <- loo::relative_eff(1 / exp(logratio), chain_id = chain_id)
  1. For that matter, az.loo doesn't use the relative=True parameter -- is there any reason for that?
  2. az.psislw requires a negative log-likelihood, but loo::psis does not?
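For background on question 2, here is a numpy sketch of the two preprocessing steps that distinguish the ESS methods: splitting each chain in half, and replacing values with their ranks across all chains. My understanding from the ArviZ docs is that `method="identity"` applies neither, while methods such as `"bulk"` split chains and rank-normalize (mapping ranks through the normal quantile function, omitted here). These helpers are illustrative and are not ArviZ's internal functions.

```python
import numpy as np

def split_chains(draws: np.ndarray) -> np.ndarray:
    """(chains, draws) -> (2 * chains, draws // 2); an odd middle draw is dropped."""
    n_draws = draws.shape[1]
    half = n_draws // 2
    # First halves of all chains, then second halves of all chains.
    return np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)

def pooled_ranks(draws: np.ndarray) -> np.ndarray:
    """Rank (1..N) of every value among all draws of all chains.

    Ties are broken by position here; a careful implementation averages them.
    """
    flat = draws.ravel()
    ranks = np.empty(flat.size, dtype=float)
    ranks[np.argsort(flat, kind="stable")] = np.arange(1, flat.size + 1)
    return ranks.reshape(draws.shape)

draws = np.random.default_rng(0).normal(size=(4, 1000))
print(split_chains(draws).shape)   # (8, 500)
print(pooled_ranks(draws).max())   # 4000.0
```

Splitting helps the ESS pick up chains that drift, and ranks make it robust to heavy tails; both matter for diagnostics, less so when the fit is assumed converged.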

OriolAbril commented 2 years ago

Hi, thanks for getting this rolling @jessegrabowski! Trying to catch up.

  1. There should not be any inconsistencies due to that. ess and its different methods are tested quite thoroughly, and I believe also against R.
  2. The method argument defines how the effective sample size is computed. There are different ways to go about this, such as "splitting" 4 chains of 1000 draws into 8 chains of 500 draws, or computing the ess on ranks instead of raw values (see, for example, https://arxiv.org/abs/1903.08008 for details; the different methods also match different functions in the posterior package).

    After a one-minute skim of the loo::relative_eff function, its computations look like they match the identity method (i.e. no splitting, no ranks, ...). This also makes sense conceptually: I assume here we don't really want to use ess as a diagnostic, but instead assume we have samples from a converged fit and want an estimate of the relative ess that is as precise as possible.

  3. IIRC, the reff parameter in psis is there to precalculate some constants and make psis more efficient. I think az.loo currently takes the mean of the ess over all posterior variables. It is probably a better idea to use the ess of the quantity on which we will use psis directly, which seems to be what is happening here.
  4. (6 can't get the number in the preview to be a 6) I don't think there is any reason for that. It might even be because az.loo was written before ess had a relative argument and it wasn't updated after that.
  5. (7) I haven't yet gone through all the LFO code in detail, but from the docs, az.psislw and loo::psis behave similarly. Their expected input is the log ratios, which for "standard" PSIS-LOO are the negative pointwise log-likelihood values. If the log ratios are defined differently here, then the negation might not be necessary anymore. Ref: http://mc-stan.org/loo/reference/psis.html#arguments, quote:

    An array, matrix, or vector of importance ratios on the log scale (for PSIS-LOO these are negative log-likelihood values).

    So az.loo passes -log_lik as the log ratios to psislw, but psislw can be used to perform Pareto smoothed importance sampling on any array of log ratios. That "constraint" is part of loo, not of psis, which is why it is applied in az.loo and not in az.psislw.
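A small numpy sketch of the two conventions, assuming `log_lik` is a (draws, time) array of pointwise log-likelihood values from the current fit; the helper names are hypothetical. In LOO (for conditionally independent observations) the target posterior conditions on one observation fewer, so the importance ratio is $1 / p(y_i \mid \theta)$. In LFO the target conditions on more observations, so by the chain rule the ratio is $\prod_j p(y_j \mid y_{1:(j-1)}, \theta)$ over the added observations $j$, and the log ratios are a positive sum. The R code linked in the TODO comment of the first post appears to pass that sum to loo::psis without negating it.

```python
import numpy as np

def loo_log_ratios(log_lik: np.ndarray, i: int) -> np.ndarray:
    """PSIS-LOO: leaving out observation i, log r_s = -log p(y_i | theta_s)."""
    return -log_lik[:, i]

def lfo_log_ratios(log_lik: np.ndarray, n_fit: int, n_target: int) -> np.ndarray:
    """PSIS-LFO: moving from a fit on the first n_fit observations to the
    first n_target observations, log r_s = +sum_j log p(y_j | ..., theta_s)
    over the added observations j in [n_fit, n_target). No minus sign."""
    return log_lik[:, n_fit:n_target].sum(axis=1)

# Rows are posterior draws, columns are time points.
log_lik = np.array([[-1.0, -2.0, -3.0],
                    [-0.5, -0.5, -0.5]])
print(loo_log_ratios(log_lik, 0))     # [1.  0.5]
print(lfo_log_ratios(log_lik, 1, 3))  # [-5. -1.]
```

Under this convention, negating the LFO sum before calling a PSIS routine would smooth the reciprocal of the intended ratios.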

Will try to go over the notebook at some point this week. Please keep asking questions here if there is something in particular you want me to focus on.

cmgoold commented 1 year ago

@jessegrabowski I wonder if this is an issue related to #2148 I posted a while ago.

I implemented PSIS-LFO-CV a while ago at work, and eventually rolled my own psislw function that matched R's method exactly.

Happy to dig into this in more detail too, as it would be useful functionality.

jessegrabowski commented 1 year ago

Could be! I've let this project fall by the wayside, but I need to come back to it in the coming weeks/months, so I'm keen to collaborate on it. Could you have a look at the notebook I posted and see if anything strikes you as obviously wrong? Maybe try it with your modified psislw function to see if that reduces the number of refits to match R? I think it was only 2-3, but it's been a while.

cmgoold commented 1 year ago

@jessegrabowski Yes, I can take a look!

HJA24 commented 10 months ago

@cmgoold Could you share your psislw-function?

mathDR commented 10 months ago

@cmgoold I was just bumping this to see if you ever dropped your psislw function into the notebook and checked whether the number of refits dropped to match R?

If not, would you be able to post your psislw function here so a member of the community could do it?

Thanks!

cmgoold commented 10 months ago

@mathDR @HJA24 Hi both. I did not get that far in the end but I can do it. Note, however, that the issue I posted about the discrepancy seems to be due to a difference in how we should be accessing the weights in R. I'll take a look today! Just don't want to lead anyone astray.

HJA24 commented 10 months ago

@cmgoold that would be much appreciated! I also don't want to lead anyone astray, but I believe the culprit is in the weights. If we look at the leave_future_out_cv-function of Jesse's notebook:

elif k_hat > tau or method == 'exact':
    # Exact ELPD
    ...
    elpd = compute_elpd(ll_predict)
else:
    # Approx ELPD
    ...
    elpd = compute_elpd(ll_predict, log_weights)

This would align with Jesse's findings: the approximate method gives different results while the exact method matches closely, and the log weights are only used in the approximate branch.

cmgoold commented 10 months ago

@jessegrabowski @mathDR @HJA24

I embedded my functions into the notebook. Note that I ran into NaN issues with the r_eff parameter but don't have time to go into why, so for now I set that variable to 1.0.
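For context on what fixing r_eff at 1.0 changes: in the PSIS paper (Vehtari et al.), and to my understanding in both loo and ArviZ (where this is the "constant" precomputed from reff), r_eff enters only through the number M of largest ratios used to fit the generalized Pareto tail, M = min(0.2 S, 3 sqrt(S / r_eff)) for S draws. Setting r_eff to 1.0 gives the shortest tail among r_eff ≤ 1, which changes the draws the k-hat estimate is based on. A NaN r_eff leaves M undefined, so this sketch (hypothetical helper name) rejects it early.

```python
import math

def pareto_tail_length(n_samples: int, reff: float) -> int:
    """Number M of largest importance ratios used to fit the Pareto tail.

    Rule from the PSIS paper: M = min(0.2 * S, 3 * sqrt(S / r_eff)), rounded up.
    A smaller r_eff (more autocorrelated draws) gives a longer tail.
    """
    if not reff > 0:  # also rejects NaN, since NaN > 0 is False
        raise ValueError(f"reff must be a positive number, got {reff!r}")
    return math.ceil(min(0.2 * n_samples, 3.0 * math.sqrt(n_samples / reff)))

for reff in (1.0, 0.25, 0.01):
    print(reff, pareto_tail_length(4000, reff))
# 1.0 190
# 0.25 380
# 0.01 800
```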

This is the final plot, which is pretty close in terms of ELPD between the two methods, but still quite far off the exact khats. I haven't checked this with the R implementation. Nonetheless, this model only refits once, which seems to match the R implementation?

[Screenshot of the final plot, 2023-11-17]

Here's my gist: https://gist.github.com/cmgoold/125eee0952c4905f3318de7ab2a11826

cmgoold commented 10 months ago

I should also point out that Bürkner et al.'s implementation normalises the exponentiated weights, as I do too. This doesn't look like it's accounted for in your original script.
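A minimal sketch of what that normalization means for the one-step-ahead ELPD, assuming `log_lik_next` holds log p(y_{i+1} | θ^(s)) for each draw and `log_weights` are the smoothed (possibly unnormalized) log weights for the current step. The names are hypothetical and not the notebook's `compute_elpd`. Without normalization, any constant offset in the log weights leaks straight into the ELPD; with it, uniform weights recover the exact-branch estimate. The ArviZ docstring describes the output of `az.psislw` as already normalized log weights, in which case re-normalizing is harmless.

```python
import numpy as np

def logsumexp(a) -> float:
    """Numerically stable log(sum(exp(a))) for a 1-D array."""
    a = np.asarray(a, dtype=float)
    a_max = np.max(a)
    if not np.isfinite(a_max):
        return float(a_max)
    return float(a_max + np.log(np.sum(np.exp(a - a_max))))

def normalize_log_weights(log_weights):
    """Shift log weights so that the exponentiated weights sum to 1."""
    return np.asarray(log_weights, dtype=float) - logsumexp(log_weights)

def elpd_exact(log_lik_next) -> float:
    """log of the mean over draws of p(y_next | theta_s), used after a refit."""
    return logsumexp(log_lik_next) - float(np.log(len(log_lik_next)))

def elpd_approx(log_lik_next, log_weights) -> float:
    """log sum_s w_s p(y_next | theta_s) with self-normalized PSIS weights."""
    return logsumexp(normalize_log_weights(log_weights) + np.asarray(log_lik_next))

# Four posterior draws, log p(y_next | theta_s) for each.
log_lik_next = np.log(np.array([0.1, 0.2, 0.3, 0.4]))
print(np.exp(elpd_exact(log_lik_next)))                    # ~0.25
print(np.exp(elpd_approx(log_lik_next, np.full(4, 5.0))))  # ~0.25: the offset cancels
```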

I'm happy to keep digging away at this and would be interested in helping out putting a PR together too, if we go down that road.

mathDR commented 7 months ago

So I would guess when this work is completed, methods for the various wrappers would be implemented?

Hey @cmgoold any PR I could review?