pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

ENH: Add checkpoints during sampling #7503

Open lucianopaz opened 1 month ago

lucianopaz commented 1 month ago

Before

No response

After

import pymc as pm

with pm.Model():
    ...
    # proposed (not yet existing) arguments for periodic checkpointing
    pm.sample(..., checkpoint_file=some_path, checkpoint_freq=10)

Context for the issue:

If one has models that take a very long time to sample, it would be great to have a way to store the steppers' state in a checkpoint file, so that if sampling stops for any reason we could pick up from where we left off. This is a very old feature request, related to #292, #143 and #3661.

Those issues discuss iter_sample, which works as a generator that one could simply pause and resume later. The problem is that it gives no access to the stepper's state. I think that we need two things to warm-start the samplers:

  1. The trace that was collected so far
  2. The step method's state
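The two requirements can be illustrated with a minimal pure-Python sketch. It is not PyMC code: `iter_draws` and `draw_chunk` are hypothetical toy random-walk samplers. A generator such as `iter_sample` can be paused, but its RNG and position live in local variables, and CPython refuses to pickle a generator object. A sampler that returns its trace and step state as plain data can be checkpointed and resumed, reproducing the uninterrupted run exactly.

```python
import pickle
import random


def iter_draws(n_draws, seed=0):
    """Generator-style toy sampler: its RNG and position live in local variables."""
    rng = random.Random(seed)
    x = 0.0
    for _ in range(n_draws):
        x += rng.uniform(-1.0, 1.0)  # toy random-walk step
        yield x


def draw_chunk(n_draws, step_state):
    """Explicit-state toy sampler: returns the new draws and the final step state."""
    rng = random.Random()
    rng.setstate(step_state["rng_state"])
    x = step_state["x"]
    draws = []
    for _ in range(n_draws):
        x += rng.uniform(-1.0, 1.0)
        draws.append(x)
    return draws, {"x": x, "rng_state": rng.getstate()}


# Pausing a generator works, but the paused generator cannot be serialized.
gen = iter_draws(100)
paused = [next(gen) for _ in range(10)]
try:
    pickle.dumps(gen)
    generator_is_picklable = True
except TypeError:  # CPython raises "cannot pickle 'generator' object"
    generator_is_picklable = False

# 1. the trace so far and 2. the step state, both as plain picklable data.
initial_state = {"x": 0.0, "rng_state": random.Random(0).getstate()}
trace, step_state = draw_chunk(10, initial_state)
checkpoint = pickle.dumps({"trace": trace, "step_state": step_state})

# Later (possibly in another process): load both and continue sampling.
restored = pickle.loads(checkpoint)
more, _ = draw_chunk(90, restored["step_state"])
resumed_trace = restored["trace"] + more
uninterrupted, _ = draw_chunk(100, initial_state)
```

Resuming from the trace alone (with a fresh RNG) would produce a valid but different continuation; only trace plus step state reproduces the uninterrupted chain.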

Currently, most samplers and step methods provide some way to get 1, but we never have access to 2. The current pymc samplers have a bunch of KeyboardInterrupt catches (here, here, here, and here). We could add a call there that also stores the step method's state. nutpie has non-blocking sampling with an abort function that is called when KeyboardInterrupt is raised; we could maybe add a similar state-recording mechanism there. blackjax has its progress-bar conditional steps, which we could try to mimic to get the same effect. numpyro does something similar with its progress bar, but it looks like it is buried much deeper than in blackjax.
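A sketch of hooking state recording into a KeyboardInterrupt handler. It assumes a hypothetical step method exposing `get_sampling_state()` (PyMC step methods have no such method today) and a hypothetical `on_interrupt` callback; `InterruptedStep` only simulates a Ctrl-C so the handler can be exercised.

```python
import random


class ToyStep:
    """Hypothetical step method that can report its state (not a PyMC class)."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.x = 0.0

    def step(self):
        self.x += self.rng.uniform(-1.0, 1.0)
        return self.x

    def get_sampling_state(self):
        return {"x": self.x, "rng_state": self.rng.getstate()}


class InterruptedStep(ToyStep):
    """Simulates the user pressing Ctrl-C after a fixed number of draws."""

    def __init__(self, interrupt_after, seed=0):
        super().__init__(seed)
        self.remaining = interrupt_after

    def step(self):
        if self.remaining == 0:
            raise KeyboardInterrupt
        self.remaining -= 1
        return super().step()


def sample(step, n_draws, on_interrupt):
    """Sampling loop that also hands the step state to `on_interrupt` on Ctrl-C."""
    trace = []
    try:
        for _ in range(n_draws):
            trace.append(step.step())
    except KeyboardInterrupt:
        # Where the loop already catches the interrupt, also record the state.
        on_interrupt(list(trace), step.get_sampling_state())
    return trace


saved = {}
trace = sample(InterruptedStep(interrupt_after=5), 100,
               lambda t, s: saved.update(trace=t, state=s))
```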

All of this to say that I think we need to define some kind of standard way for the samplers to provide their state information. The specific samplers would then have to conform to the standard using whatever internal things they need. For pymc samplers it would be some way to recreate the step methods (maybe using some kind of __setstate__ and __getstate__), for nutpie it would have to be some new datatype that could be sent into Rust, and for blackjax it could be the kernel and random keys. I think the important thing is to agree on the standard that samplers should conform to; once we have it, we could build support for checkpoints and for restarting sampling from them later.
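One way to express such a standard in Python is a `typing.Protocol`. The method names `get_sampling_state`/`set_sampling_state` and both backend classes are hypothetical; they only show that backends with very different internals (a Python RNG plus a tuned step size, versus an explicit key plus kernel parameters) can expose the same interface. Note that `runtime_checkable` only checks that the methods exist, not their signatures.

```python
import random
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsSamplingState(Protocol):
    """Hypothetical standard interface that each sampler backend would implement."""

    def get_sampling_state(self) -> dict[str, Any]: ...

    def set_sampling_state(self, state: dict[str, Any]) -> None: ...


class PyMCLikeStep:
    """Backend whose state is a Python RNG plus a tuned step size."""

    def __init__(self, seed: int, step_size: float) -> None:
        self.rng = random.Random(seed)
        self.step_size = step_size

    def get_sampling_state(self) -> dict[str, Any]:
        return {"rng_state": self.rng.getstate(), "step_size": self.step_size}

    def set_sampling_state(self, state: dict[str, Any]) -> None:
        self.rng.setstate(state["rng_state"])
        self.step_size = state["step_size"]


class KeyedKernel:
    """Backend whose state is an explicit random key plus kernel parameters."""

    def __init__(self, key: tuple[int, int], params: dict[str, float]) -> None:
        self.key = key
        self.params = dict(params)

    def get_sampling_state(self) -> dict[str, Any]:
        return {"key": self.key, "params": dict(self.params)}

    def set_sampling_state(self, state: dict[str, Any]) -> None:
        self.key = tuple(state["key"])
        self.params = dict(state["params"])


backends = [PyMCLikeStep(seed=1, step_size=0.5),
            KeyedKernel(key=(0, 1), params={"step_size": 0.1})]
# Uniform call site, backend-specific contents.
saved_states = [b.get_sampling_state() for b in backends]
```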

ricardoV94 commented 1 month ago

It seems to me that the important thing is to have a pm.sample that can resume from a given trace/state. I'm not sure pm.sample should carry the extra burden of checkpoints; that's something the user could easily cook up in an outer loop (and that we could offer as a utility) if the functionality to resume were there?
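If a resume capability existed, the outer loop could look roughly like this sketch. `sample_chunk` is a hypothetical resumable sampler returning its draws and final state, and `sample_with_checkpoints` is a hypothetical utility. The checkpoint is written to a temporary file first and then moved over the old one with `os.replace`, so an interruption during the write does not destroy the previous checkpoint.

```python
import os
import pickle
import random
import tempfile


def sample_chunk(n_draws, state):
    """Hypothetical resumable sampler: returns the draws and its final state."""
    rng = random.Random()
    rng.setstate(state["rng_state"])
    x = state["x"]
    draws = []
    for _ in range(n_draws):
        x += rng.uniform(-1.0, 1.0)
        draws.append(x)
    return draws, {"x": x, "rng_state": rng.getstate()}


def sample_with_checkpoints(n_draws, chunk_size, path, seed=0):
    """Outer loop: resume from `path` if it exists, checkpoint after each chunk."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            trace, state = pickle.load(f)
    else:
        trace, state = [], {"x": 0.0, "rng_state": random.Random(seed).getstate()}
    while len(trace) < n_draws:
        draws, state = sample_chunk(min(chunk_size, n_draws - len(trace)), state)
        trace.extend(draws)
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump((trace, state), f)
        # Replace the old checkpoint only once the new one is fully written.
        os.replace(tmp_path, path)
    return trace


with tempfile.TemporaryDirectory() as d:
    full = sample_with_checkpoints(50, 7, os.path.join(d, "a.pkl"))
    # Simulate a run that stopped after 20 draws, then resume it to 50.
    partial = sample_with_checkpoints(20, 7, os.path.join(d, "b.pkl"))
    resumed = sample_with_checkpoints(50, 7, os.path.join(d, "b.pkl"))
```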

I'm not sure how to interact with external samplers; pm.sample there is basically just a gateway to the external samplers and doesn't do anything itself once those are launched.

A first step would be for our samplers to return the internal state at the last step and also allow them to resume from an externally-stored internal state?

lucianopaz commented 1 month ago

Yes, the first step is to think of some kind of standardised mechanism that samplers should expose. From my perspective, we need:

  1. Something that dumps the sampler state somewhere
  2. Something that can build a new sampler with the same state as the one that was dumped in 1
  3. Something that dumps the trace so far
  4. Something that can build a new trace object from what was dumped in 3
  5. Some way of using the results of 2 and 4 to resume sampling
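The five operations could map onto functions like these. All names (`ToyStep`, `dump_step_state`, `load_step`, `dump_trace`, `load_trace`, `resume`) are hypothetical, and `pickle` stands in for whatever storage format is eventually chosen; the point is that 2 and 4 rebuild objects from plain data, and 5 only needs their results.

```python
import pickle
import random


class ToyStep:
    """Hypothetical step method used to illustrate the five operations."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.x = 0.0

    def step(self):
        self.x += self.rng.uniform(-1.0, 1.0)
        return self.x


def dump_step_state(step):
    """1. Dump the sampler state."""
    return pickle.dumps({"x": step.x, "rng_state": step.rng.getstate()})


def load_step(blob):
    """2. Build a new sampler with the dumped state."""
    state = pickle.loads(blob)
    step = ToyStep()
    step.x = state["x"]
    step.rng.setstate(state["rng_state"])
    return step


def dump_trace(trace):
    """3. Dump the trace so far."""
    return pickle.dumps(list(trace))


def load_trace(blob):
    """4. Build a new trace object from the dump."""
    return pickle.loads(blob)


def resume(step, trace, n_more):
    """5. Continue sampling from the results of 2 and 4."""
    for _ in range(n_more):
        trace.append(step.step())
    return trace


original = ToyStep(seed=3)
trace = [original.step() for _ in range(5)]
step_blob, trace_blob = dump_step_state(original), dump_trace(trace)
resumed = resume(load_step(step_blob), load_trace(trace_blob), 5)

reference = ToyStep(seed=3)
expected = [reference.step() for _ in range(10)]
```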

Once that high-level interface is defined, we can try to get external samplers to conform to it as well. Later we can also add an easy-to-use utility that helps people orchestrate this, saving the state somewhere.

lucianopaz commented 1 month ago

@ricardoV94, while looking into this, I ran into a potential problem. PyMC step methods are intertwined with the model object, its value variables and some compiled logp functions. Serializing the step methods might be possible, but it looks like it would be very hard to do without also serializing the model, which would leave the restored step method working on a sort of copy of the model. This might not be a problem at all, but I'll have to think it through a bit longer.
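The coupling problem shows up even in a toy example built from `types.SimpleNamespace` stand-ins (hypothetical, not PyMC objects): when a step method holds a direct reference to its model, pickling the step serializes the whole model, and unpickling yields a step bound to a copy rather than to the original model.

```python
import pickle
from types import SimpleNamespace

# Hypothetical stand-ins: a "model" carrying observed data, and a step method
# that keeps a direct reference to that model.
model = SimpleNamespace(name="toy_model", data=list(range(10_000)))
step = SimpleNamespace(model=model, scale=0.3)

blob = pickle.dumps(step)  # serializes the entire model together with the step
restored = pickle.loads(blob)

blob_size = len(blob)                             # dominated by the model's data
state_blob = pickle.dumps({"scale": step.scale})  # the step's own state alone
shares_model = restored.model is model            # False: it points at a copy
```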

ricardoV94 commented 1 month ago

The step samplers could (they kind of do already) take the model / logp function as input.

lucianopaz commented 1 month ago

My original idea for points 1 and 2 (dump and load the sampler's state) was to rely on the standard pickle approach. The problem is that if you dump the pickled sampler, you also have to dump the pickled model and whatever compiled functions the step method holds as attributes. I'm almost certain that this approach will have a high memory footprint. I'm also not sure whether unpickling a step method will cause problems by leaving it pointing to a cloned version of the model rather than the model that should actually be used. I would have loved for the step method to hold weaker references (I don't necessarily mean the weakref module) to the model's variables and compiled functions, such as variable names. That way, pickling and unpickling the step methods would always safely reference the same objects.
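A sketch of the "weaker reference" idea, combined with the suggestion that step samplers take the model as input: the hypothetical `NamedStep` keeps only variable names and numbers in its state, and `from_state` re-resolves those names against whichever model is passed in. `ToyModel` is a stand-in mapping names to logp callables; no real PyMC model or compiled function is involved.

```python
import pickle
import random


class ToyModel:
    """Stand-in for a model: value-variable names mapped to logp callables."""

    def __init__(self, logps):
        self.logps = logps


class NamedStep:
    """Hypothetical step method that refers to model variables only by name."""

    def __init__(self, model, var_names, scale=0.5, seed=0):
        self.model = model
        self.var_names = list(var_names)
        self.scale = scale
        self.rng = random.Random(seed)

    def get_sampling_state(self):
        # Plain data only: names and numbers, no model or compiled functions.
        return {"var_names": self.var_names, "scale": self.scale,
                "rng_state": self.rng.getstate()}

    @classmethod
    def from_state(cls, state, model):
        # The model is passed in again, so names resolve against the model in use.
        step = cls(model, state["var_names"], state["scale"])
        step.rng.setstate(state["rng_state"])
        return step

    def logp_functions(self):
        return [self.model.logps[name] for name in self.var_names]


model = ToyModel({"mu": lambda x: -0.5 * x * x, "sigma": lambda x: -abs(x)})
step = NamedStep(model, ["mu"], scale=0.2, seed=4)
blob = pickle.dumps(step.get_sampling_state())  # small: the model is not inside
restored = NamedStep.from_state(pickle.loads(blob), model)
```

Because only names cross the serialization boundary, the restored step points at the live model object rather than at a copy.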

Since I'm a bit afraid that using the pickle approach might be bad, I'll try to add methods to already instantiated step methods that can return or set their sampling state.

ricardoV94 commented 1 month ago

> Since I'm a bit afraid that using the pickle approach might be bad, I'll try to add methods to already instantiated step methods that can return or set their sampling state.

I think that's much cleaner anyway.

ricardoV94 commented 1 month ago

What if the trace always has access to the last state of each step sampler in a dictionary (no need to even dump, only load)?
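A sketch of that idea with hypothetical `Trace` and `ToyStep` classes: the trace records each step method's state alongside every draw, keeping only the latest state per step, so resuming only needs a load through `set_sampling_state`.

```python
import random


class Trace:
    """Hypothetical trace that keeps the latest state of every step method."""

    def __init__(self):
        self.draws = []
        self.last_step_states = {}  # step name -> state after the most recent draw

    def record(self, draw, step_states):
        self.draws.append(draw)
        self.last_step_states.update(step_states)


class ToyStep:
    """Hypothetical step method with get/set methods for its sampling state."""

    def __init__(self, name, seed):
        self.name = name
        self.rng = random.Random(seed)
        self.x = 0.0

    def step(self):
        self.x += self.rng.uniform(-1.0, 1.0)
        return self.x

    def get_sampling_state(self):
        return {"x": self.x, "rng_state": self.rng.getstate()}

    def set_sampling_state(self, state):
        self.x = state["x"]
        self.rng.setstate(state["rng_state"])


steps = [ToyStep("a", seed=1), ToyStep("b", seed=2)]
trace = Trace()
for _ in range(5):
    draw = {s.name: s.step() for s in steps}
    trace.record(draw, {s.name: s.get_sampling_state() for s in steps})

# Resume: nothing extra was dumped, the trace already carries the states.
fresh = [ToyStep("a", seed=0), ToyStep("b", seed=0)]
for s in fresh:
    s.set_sampling_state(trace.last_step_states[s.name])
next_draw = {s.name: s.step() for s in fresh}
expected = {s.name: s.step() for s in steps}
```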