pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

FR Streaming MCMC interface for big models #2843

Open fritzo opened 3 years ago

fritzo commented 3 years ago

This issue proposes a streaming architecture for MCMC on models with large memory footprint.

The problem this addresses is that, in models with high-dimensional latents (say >1M latent variables), it becomes difficult to save a list of samples, especially on GPUs with limited memory. The proposed solution is to eagerly compute statistics on those samples, and discard them during inference.

@fehiepsi suggested creating a new MCMC class (say StreamingMCMC) with an interface similar to MCMC's. It would still be kernel-agnostic (using either HMC or NUTS) but would follow an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory-constrained, it is reasonable to omit multiprocessing support from StreamingMCMC.

Along with the new StreamingMCMC class, I think there should be a set of helpers to compute statistics incrementally from sample streams, e.g. mean, variance, covariance, and r_hat.
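As an illustration of what such a helper might look like, here is a minimal sketch of Welford's one-pass algorithm, which updates a per-site mean and variance from one sample dict at a time so the sample can be discarded right afterwards. The class name `WelfordMeanVariance` is hypothetical and this is not Pyro's API (a related implementation lives in `pyro.ops.welford`). It only uses `+`, `-`, `*` and `/`, so it should work with Python floats or `torch.Tensor` values alike.

```python
from typing import Any, Dict, Tuple


class WelfordMeanVariance:
    """Hypothetical streaming helper: per-site mean and variance via Welford's algorithm."""

    def __init__(self) -> None:
        self.count = 0
        self.mean: Dict[str, Any] = {}
        self.m2: Dict[str, Any] = {}  # running sum of squared deviations from the mean

    def update(self, sample: Dict[str, Any]) -> None:
        # Assumes every sample contains the same set of site names.
        self.count += 1
        for name, x in sample.items():
            old_mean = self.mean.get(name, 0.0)
            delta = x - old_mean
            new_mean = old_mean + delta / self.count
            self.mean[name] = new_mean
            self.m2[name] = self.m2.get(name, 0.0) + delta * (x - new_mean)

    def get(self) -> Dict[str, Tuple[Any, Any]]:
        # (mean, unbiased variance) per site; requires at least two samples.
        return {
            name: (self.mean[name], self.m2[name] / (self.count - 1))
            for name in self.mean
        }
```

Memory stays constant in the number of samples: each update touches only the running mean and the running sum of squared deviations.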

Tasks (to be split into multiple PRs)

@mtsokol

@fritzo

mtsokol commented 3 years ago

Hi @fritzo!

I was looking for an issue to work on, and if this one is free I would like to try solving it.

As I understand it, the main point here would be to implement a StreamingMCMC class that has no get_samples method and instead keeps incrementally updated statistics in its state (assuming all of them can be computed incrementally; can they?). Something like this:

```py
class StreamingMCMC:
    def __init__(self, *args, **kwargs):
        ...
        self.incremental_mean = 0.0
        self.incremental_variance = ...  # placeholder
        # and the rest of the statistics that the 'summary' method will use

    def run(self, *args, **kwargs):
        ...
        num_samples = 0
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            # running-mean update: m_n = m_{n-1} + (x_n - m_{n-1}) / n
            self.incremental_mean += (x - self.incremental_mean) / num_samples
            # ...and the rest of the statistics
            del x
        ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns the computed incremental statistics
        ...
```
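On the question of which statistics can be computed incrementally: the running-mean line in this sketch is exact rather than approximate. Writing $m_n$ for the mean of the first $n$ samples $x_1, \dots, x_n$:

$$
m_n = \frac{1}{n}\sum_{i=1}^{n} x_i = \frac{(n-1)\,m_{n-1} + x_n}{n} = m_{n-1} + \frac{x_n - m_{n-1}}{n}.
$$

The variance admits a similar one-pass update (Welford's algorithm) via the running sum of squared deviations $M_n = \sum_{i=1}^{n} (x_i - m_n)^2$:

$$
M_n = M_{n-1} + (x_n - m_{n-1})(x_n - m_n), \qquad s_n^2 = \frac{M_n}{n-1}.
$$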

Also, additional tests in test_mcmc_api.py should run both MCMC classes and compare the final statistics.

As described, StreamingMCMC shouldn't manually support multiprocessing via _Worker, because CUDA, which is assumed to be the main beneficiary of this new class, handles vectorization by itself (is that correct?).

Follow-up question: should StreamingMCMC have a num_chains argument and, for num_chains > 1, just compute the chains sequentially, or should it omit this argument?


It's seemingly straightforward, but I've just started looking at the source code. Are there any pitfalls that I should bear in mind?

fritzo commented 3 years ago

Hi @mtsokol that sounds great and I'm happy to provide any review and guidance.

Your sketch looks good. The only change I'd suggest is that we think hard about making a fully extensible interface for computing streaming statistics, so that users can easily stream other custom things. For task 2 above, I was thinking of creating a new module, say pyro.ops.streaming, with a class hierarchy of basic streamable statistics:

```py
from abc import ABC, abstractmethod
from typing import Dict, Set

import torch


class StreamingStatistic(ABC):
    """Base class for streamable statistics"""

    @abstractmethod
    def update(self, sample: Dict[str, torch.Tensor]) -> None:
        """Update state from a single sample."""
        raise NotImplementedError

    @abstractmethod
    def merge(self, other: "StreamingStatistic") -> "StreamingStatistic":
        """Combine two aggregate statistics, e.g. from different chains."""
        assert type(self) == type(other)
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Dict[str, torch.Tensor]:
        """Return the aggregate statistic."""
        raise NotImplementedError
```

Together with a set of basic concrete statistics (see also pyro.ops.welford for an implementation, though with a non-general interface):

```py
class Count(StreamingStatistic): ...
class Mean(StreamingStatistic): ...
class MeanAndVariance(StreamingStatistic): ...
class MeanAndCovariance(StreamingStatistic): ...
class RHat(StreamingStatistic): ...
```

And maybe a restriction to a subset of names:

```py
class SubsetStatistic(StreamingStatistic):
    def __init__(self, names: Set[str], base_stat: StreamingStatistic):
        self.names = names
        self.base_stat = base_stat

    def update(self, sample):
        sample = {k: v for k, v in sample.items() if k in self.names}
        self.base_stat.update(sample)

    def merge(self, other):
        # Merge the wrapped statistics; both sides must restrict to the same names.
        assert type(self) == type(other) and self.names == other.names
        return SubsetStatistic(self.names, self.base_stat.merge(other.base_stat))

    def get(self):
        return self.base_stat.get()
```
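To make the `merge` contract concrete, here is a hedged sketch of two of the concrete statistics listed above, `Count` and `Mean`, written against the same three-method interface (restated compactly so the block runs on its own). The merge rule for means is the count-weighted average of the two partial means. Values are assumed to support `+`, `-`, `*`, `/`, so Python floats and `torch.Tensor`s should both work; the eventual `pyro.ops.streaming` classes may well differ.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class StreamingStatistic(ABC):
    """Compact restatement of the interface sketched above."""

    @abstractmethod
    def update(self, sample: Dict[str, Any]) -> None: ...

    @abstractmethod
    def merge(self, other: "StreamingStatistic") -> "StreamingStatistic": ...

    @abstractmethod
    def get(self) -> Dict[str, Any]: ...


class Count(StreamingStatistic):
    def __init__(self) -> None:
        self.count = 0

    def update(self, sample):
        self.count += 1

    def merge(self, other):
        assert type(self) == type(other)
        result = Count()
        result.count = self.count + other.count
        return result

    def get(self):
        return {"count": self.count}


class Mean(StreamingStatistic):
    def __init__(self) -> None:
        self.count = 0
        self.means: Dict[str, Any] = {}

    def update(self, sample):
        # Assumes every sample contains the same site names.
        self.count += 1
        for name, value in sample.items():
            prev = self.means.get(name, 0.0)
            self.means[name] = prev + (value - prev) / self.count

    def merge(self, other):
        # Count-weighted average of the two partial means (assumes both saw samples).
        assert type(self) == type(other)
        result = Mean()
        result.count = self.count + other.count
        for name in self.means:
            result.means[name] = (
                self.count * self.means[name] + other.count * other.means[name]
            ) / result.count
        return result

    def get(self):
        return dict(self.means)
```

Keeping `merge` separate from `update` lets per-chain statistics be combined after the chains finish, which is what cross-chain diagnostics need.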

I think that might be enough of an interface, but we might want more details in the __init__ methods.

Then once we have basic statistics we can make your interface generic and extensible:

```py
class StreamingMCMC:
    def __init__(self, *args, statistics=None, **kwargs):
        ...
        if statistics is None:
            statistics = [Count(), MeanAndVariance()]
        self._statistics = statistics

    def run(self, *args, **kwargs):
        ...
        num_samples = 0
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            for stat in self._statistics:
                stat.update(x)
            del x
        ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns computed incremental statistics
        ...
```
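The control flow of this generic `run` loop can be exercised without Pyro at all. The sketch below is a hypothetical stand-in: `ToyStreamingMCMC`, `RunningMax`, and `fake_sampler` are made-up names, and a real `StreamingMCMC` would wrap an HMC/NUTS kernel rather than a plain generator. Each yielded sample updates every statistic and is then dropped, so memory use does not grow with the number of samples.

```python
from typing import Any, Dict, Iterator, List, Tuple


class RunningMax:
    """Toy statistic: per-site maximum over all samples (scalar floats here)."""

    def __init__(self) -> None:
        self.values: Dict[str, float] = {}

    def update(self, sample: Dict[str, float]) -> None:
        for name, value in sample.items():
            self.values[name] = max(self.values.get(name, value), value)

    def get(self) -> Dict[str, float]:
        return dict(self.values)


def fake_sampler(num_chains: int, num_samples: int) -> Iterator[Tuple[Dict[str, float], int]]:
    """Stands in for a kernel: yields (sample, chain_id), one chain after another."""
    for chain_id in range(num_chains):
        for i in range(num_samples):
            yield {"x": float(i + 10 * chain_id)}, chain_id


class ToyStreamingMCMC:
    def __init__(self, sampler, statistics=None) -> None:
        self.sampler = sampler
        self._statistics = statistics if statistics is not None else [RunningMax()]
        self.num_samples = 0

    def run(self, *args: Any, **kwargs: Any) -> None:
        for sample, chain_id in self.sampler(*args, **kwargs):
            self.num_samples += 1
            for stat in self._statistics:
                stat.update(sample)
            del sample  # the driver keeps no reference to the sample itself

    def summary(self) -> List[Dict[str, float]]:
        return [stat.get() for stat in self._statistics]
```

For example, `ToyStreamingMCMC(fake_sampler, [RunningMax()]).run(2, 3)` streams six samples through the statistic without ever storing them in a list.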

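What I'd really like is to be able to define custom statistics for a particular problem, e.g. saving a list of norms:

```py
from collections import defaultdict

import torch


class ListOfNorms(StreamingStatistic):
    def __init__(self):
        self._lists = defaultdict(list)

    def update(self, data):
        for k, v in data.items():
            self._lists[k].append(torch.linalg.norm(v.detach().reshape(-1)).item())

    def merge(self, other):
        # Concatenate the per-site lists from both statistics (e.g. two chains).
        assert type(self) == type(other)
        result = ListOfNorms()
        for k in set(self._lists) | set(other._lists):
            result._lists[k] = self._lists.get(k, []) + other._lists.get(k, [])
        return result

    def get(self):
        return dict(self._lists)


my_mcmc = StreamingMCMC(..., statistics=[MeanAndVariance(), ListOfNorms()])
```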

WDYT?

fritzo commented 3 years ago

Addressing your earlier questions:

> Also, additional tests in test_mcmc_api.py should run both MCMC classes and compare the final statistics.

Correct: most existing tests should be parametrized with

```py
@pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC])
```

> As described, StreamingMCMC shouldn't manually support multiprocessing via _Worker, because CUDA, which is assumed to be the main beneficiary of this new class, handles vectorization by itself (is that correct?).

Almost. The main beneficiaries here are large models that push against memory limits and therefore necessitate streaming rather than saving all samples in memory. If you're pushing against memory limits, you'll want to avoid parallelizing and instead compute chains sequentially (which can itself be seen as a streaming operation). In practice, yes, most models that hit memory limits are run on GPU, but multicore CPU models can also be very performant.

> Should StreamingMCMC have a num_chains argument and, for num_chains > 1, just compute the chains sequentially, or omit this argument?

StreamingMCMC should still support num_chains > 1 (which is valuable for assessing convergence), but should compute the chains sequentially.
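One way to support num_chains > 1 while computing the chains sequentially is to run one chain after another, keep one small statistic object per chain, and reduce them with `merge` at the end. A minimal sketch, assuming a hypothetical `run_chain(chain_id)` generator in place of a real kernel and a toy `SumAndCount` statistic (neither is Pyro API):

```python
from functools import reduce
from typing import Callable, Dict, Iterable, List, Tuple


class SumAndCount:
    """Toy mergeable statistic: per-site sums plus a sample count."""

    def __init__(self) -> None:
        self.count = 0
        self.sums: Dict[str, float] = {}

    def update(self, sample: Dict[str, float]) -> None:
        self.count += 1
        for name, value in sample.items():
            self.sums[name] = self.sums.get(name, 0.0) + value

    def merge(self, other: "SumAndCount") -> "SumAndCount":
        result = SumAndCount()
        result.count = self.count + other.count
        for name in set(self.sums) | set(other.sums):
            result.sums[name] = self.sums.get(name, 0.0) + other.sums.get(name, 0.0)
        return result

    def get(self) -> Dict[str, float]:
        # Per-site mean over all samples seen.
        return {name: total / self.count for name, total in self.sums.items()}


def run_chains_sequentially(
    run_chain: Callable[[int], Iterable[Dict[str, float]]],
    num_chains: int,
) -> Tuple[List[SumAndCount], SumAndCount]:
    """Requires num_chains >= 1; returns per-chain statistics and their merged total."""
    per_chain = []
    for chain_id in range(num_chains):  # chains run one after another, never in parallel
        stat = SumAndCount()
        for sample in run_chain(chain_id):
            stat.update(sample)
        per_chain.append(stat)
    total = reduce(lambda a, b: a.merge(b), per_chain)
    return per_chain, total
```

Keeping the per-chain objects around (they are small) is what later allows cross-chain diagnostics such as r_hat.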

fritzo commented 3 years ago

@mtsokol would you want to work on this in parallel? Maybe you could implement the StreamingMCMC class using hand-coded statistics, I could implement a basic pyro.ops.streaming module, and over the course of a few PRs we could meet in the middle?

mtsokol commented 3 years ago

@fritzo thanks for the guidance! Right now I'm looking at the current implementation and starting to work on this. The StreamingStatistic abstraction sounds good to me: StreamingMCMC will only iterate over samples and call methods on the passed objects implementing that interface.

Sure! I can start working on StreamingMCMC and follow the StreamingStatistic design from the start. When your PR is ready, I will adjust my implementation.

Should I introduce an AbstractMCMC interface that both the existing MCMC and StreamingMCMC would implement?
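For reference, such a shared interface could be as small as an abstract base class over the methods both classes already expose (`run`, `diagnostics`, `summary`). This is a hypothetical sketch: the class name comes from the question above, but the method set and signatures are assumptions, and `get_samples` is deliberately left out because StreamingMCMC would not store samples.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class AbstractMCMC(ABC):
    """Hypothetical common base for MCMC and StreamingMCMC."""

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> None:
        """Run the sampler on the model arguments."""

    @abstractmethod
    def diagnostics(self) -> Dict[str, Any]:
        """Return convergence diagnostics (e.g. r_hat) per site."""

    @abstractmethod
    def summary(self, prob: float = 0.9) -> None:
        """Report summary statistics of the posterior."""


class DummyMCMC(AbstractMCMC):
    """Trivial concrete subclass, only to show the contract can be satisfied."""

    def __init__(self) -> None:
        self.ran = False

    def run(self, *args, **kwargs):
        self.ran = True

    def diagnostics(self):
        return {}

    def summary(self, prob=0.9):
        print(f"ran={self.ran}, prob={prob}")
```

Python refuses to instantiate `AbstractMCMC` directly, so a subclass that forgets one of the three methods fails at construction time rather than mid-run.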

fritzo commented 3 years ago

Feel free to implement an AbstractMCMC interface if you like. I defer to your design judgement here.

mtsokol commented 3 years ago

@fritzo After thinking about how to handle the streamed samples, I wanted to ask a few more questions:

  1. Right now, samples are yielded by the sampler, and each one is appended to the corresponding chain's list via z_flat_acc[chain_id].append(x_cloned). Then we reshape to get rid of the last (flattened) dimension and have dict entries in its place (based on the yielded structure). Then we perform an element-wise transform (with self.transforms), where the transform is determined by the dict entry. A streaming approach would go as follows: again, each sample is yielded by the sampler; the sample is used to construct a dict (based on the yielded structure); that single dict is transformed (with self.transforms); and then the sample is fed to each statistic via update(self, sample: Dict[str, torch.Tensor]). (So each sample results in constructing a new dict; is that OK?) WDYT?

  2. Should StreamingStatistic.update be chain_id-aware, e.g. update(self, chain_id: int, sample: Dict[str, torch.Tensor]), so that it can compute chain-related diagnostics and support the group_by_chain argument?

  3. Why do we need to clone: x_cloned = x.clone() when num_chains > 1?


Follow-up on the first question: if this makes a performance difference (I'm just wondering; it might be irrelevant), StreamingMCMC could work in batches instead of streaming each sample to the statistics. E.g. we could introduce an additional argument batch_size=100, so StreamingMCMC would wait until it has aggregated 100 samples, then construct the dict, perform the transformations, and feed the whole batch to the statistics. (But maybe constructing a dict and transforming each sample separately isn't really an overhead; once the implementation is ready, I can run memory and time measurements.) WDYT?
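The per-sample path described in question 1 (split one flat sample into named sites, apply each site's transform, then feed the resulting dict to every statistic) can be sketched without Pyro. This is a hedged illustration: `unflatten`, `process_sample`, `site_sizes`, `Recorder`, and the `transforms` dict of plain callables are hypothetical stand-ins for the kernel's yielded structure and `self.transforms`, and plain Python lists replace tensors.

```python
import math
from typing import Any, Callable, Dict, List


def unflatten(flat: List[float], site_sizes: Dict[str, int]) -> Dict[str, List[float]]:
    """Split a flat sample into named sites, in the insertion order of site_sizes."""
    sample, offset = {}, 0
    for name, size in site_sizes.items():
        sample[name] = flat[offset:offset + size]
        offset += size
    assert offset == len(flat), "site sizes must cover the flat sample exactly"
    return sample


def process_sample(
    flat: List[float],
    site_sizes: Dict[str, int],
    transforms: Dict[str, Callable[[float], float]],
    statistics: List[Any],
) -> Dict[str, List[float]]:
    sample = unflatten(flat, site_sizes)
    transformed = {}
    for name, values in sample.items():
        fn = transforms.get(name, lambda v: v)  # identity when a site has no transform
        transformed[name] = [fn(v) for v in values]
    for stat in statistics:
        stat.update(transformed)
    return transformed  # returned only so the example can be inspected


class Recorder:
    """Toy statistic that just remembers the last sample it saw."""

    def __init__(self) -> None:
        self.last = None

    def update(self, sample):
        self.last = sample
```

For example, `process_sample([0.0, 1.0, 2.0], {"scale": 1, "loc": 2}, {"scale": math.exp}, [Recorder()])` yields `{"scale": [1.0], "loc": [1.0, 2.0]}`: one small dict per sample, which is then free to be discarded.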

fritzo commented 3 years ago

@mtsokol answering your latest questions:

  1. tl;dr keep it simple. I do not foresee a performance hit here: it is cheap to create dicts, and StreamingMCMC will typically be used with large memory-bound models with huge tensors, where the Python overhead is negligible. For the same reason, I think we should avoid batching, since it increases memory overhead. (In fact, I suspect the bottleneck will be in pyro.ops.streaming, where we may need to refactor to perform tensor operations in place.)
  2. Yes, I believe we will want to compute both per-chain and total-aggregated statistics. I have added a .merge() operation in #2856 to make this easy for you. The main motivation is to compute cross-chain statistics like r_hat.
  3. It looks like the cloning is explained earlier in the file. I would recommend keeping that logic.

https://github.com/pyro-ppl/pyro/blob/4a61ef2f9050ef81d1b0aa148d14ecc876f24a51/pyro/infer/mcmc/api.py#L389-L392
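The cross-chain motivation for `merge` can be illustrated with the classic Gelman-Rubin statistic, which needs only each chain's sample count, mean, and variance, all of which can be accumulated in a streaming fashion. This is a sketch of the basic (non-split) formula, assuming all chains have the same length; `gelman_rubin` is a hypothetical helper, and the estimator Pyro actually uses may differ (e.g. by splitting chains).

```python
import math
from typing import List, Tuple


def gelman_rubin(chain_stats: List[Tuple[int, float, float]]) -> float:
    """R-hat from per-chain (count, mean, unbiased variance) tuples for one scalar site.

    Assumes at least two chains, equal chain lengths, and nonzero within-chain variance.
    """
    num_chains = len(chain_stats)
    n = chain_stats[0][0]
    assert num_chains >= 2 and all(c == n for c, _, _ in chain_stats)
    means = [m for _, m, _ in chain_stats]
    grand_mean = sum(means) / num_chains
    # B / n: sample variance of the chain means
    b_over_n = sum((m - grand_mean) ** 2 for m in means) / (num_chains - 1)
    # W: average within-chain variance
    w = sum(v for _, _, v in chain_stats) / num_chains
    var_plus = (n - 1) / n * w + b_over_n
    return math.sqrt(var_plus / w)
```

In a streaming setting, each tuple would come from a per-chain mean-and-variance statistic (e.g. a Welford update), so no samples need to be retained; values well above 1 indicate the chains have not mixed.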

mtsokol commented 3 years ago

Hi @fritzo!

I was wondering what I can try to do next.

Since "Add r_hat to pyro.ops.streaming" is completed, I looked into a streaming version of n_eff (ESS), but after a short inspection of the current implementation it looks infeasible to me, as it requires e.g. autocorrelations at many lags.

Apart from that I can definitely try:

Create a tutorial using StreamingMCMC on a big model

Could you suggest a problem and model that would be suitable for that?

Also, I could combine the new tutorial with your suggestion in the last bullet point of https://github.com/pyro-ppl/pyro/issues/2803#issuecomment-836644916 (showing how Predictive can be used interchangeably with poutine methods).

WDYT?


That would be a documentation task, though, and I was also looking for an implementation task. Do you have something I could try?