fritzo opened this issue 3 years ago
Hi @fritzo!
I was searching for an issue and if this one's free I would like to try solving it.
So as I understand it, the main point here would be to implement a StreamingMCMC class that doesn't provide a get_samples method and instead keeps incrementally updated statistics in its state (assuming all of them can be computed incrementally - can they?). Something like this:
```python
class StreamingMCMC:
    def __init__(self, *args, **kwargs):
        self._num_samples = 0
        self.incremental_mean = 0.0
        self.incremental_variance = 0.0
        # ...and the rest of the statistics that will be used by the 'summary' method

    def run(self, *args, **kwargs):
        for x, chain_id in self.sampler.run(*args, **kwargs):
            self._num_samples += 1
            # incremental (Welford-style) update of the running mean
            self.incremental_mean += (x - self.incremental_mean) / self._num_samples
            # ...and the rest of the statistics
            del x  # discard the sample immediately

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just return the computed incremental statistics
        ...
```
Also, in test_mcmc_api.py additional tests should run both types of MCMC classes and compare their final statistics.
As described, StreamingMCMC shouldn't support multiprocessing manually via _Worker, because here CUDA, which is thought to be the main beneficiary of this new class, handles vectorization by itself. (Is that correct?)
Follow-up question: should StreamingMCMC have a num_chains argument and, for num_chains > 1, just compute the chains sequentially, or should it omit this argument?
It's seemingly straightforward, but I've just started looking at the source code. Are there any pitfalls that I should bear in mind?
Hi @mtsokol, that sounds great, and I'm happy to provide review and guidance.
Your sketch looks good. The only difference I'd suggest would be for us to think hard about making a fully extensible interface for computing streaming statistics, so that users can easily stream other custom things. I was thinking, with task 2 above, to create a new module, say pyro.ops.streaming, with a class hierarchy of basic streamable statistics:
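For instance, a minimal sketch of what such a hierarchy might look like (class and method names here are illustrative, not necessarily the final pyro.ops.streaming API):

```python
from abc import ABC, abstractmethod
from typing import Dict

import torch


class StreamingStatistic(ABC):
    """Illustrative base class for statistics computed over a stream of samples."""

    @abstractmethod
    def update(self, sample: Dict[str, torch.Tensor]) -> None:
        """Update internal state given a single sample."""
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Dict[str, torch.Tensor]:
        """Return the statistic aggregated over all samples seen so far."""
        raise NotImplementedError


class Mean(StreamingStatistic):
    """Running mean, updated one sample at a time (no samples are stored)."""

    def __init__(self):
        self.count = 0
        self.mean = {}

    def update(self, sample):
        self.count += 1
        for k, v in sample.items():
            if k not in self.mean:
                self.mean[k] = v.detach().clone()
            else:
                self.mean[k] += (v.detach() - self.mean[k]) / self.count

    def get(self):
        return self.mean
```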
I think that might be enough of an interface, but we might want more details in the __init__ methods.
Then once we have basic statistics we can make your interface generic and extensible:
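Something along these lines, building on the StreamingStatistic sketch above (again just a sketch of a possible interface, not an actual implementation; `sampler` is assumed to yield `(sample_dict, chain_id)` pairs):

```python
class StreamingMCMC:
    """MCMC runner that pipes every sample through user-supplied streaming statistics."""

    def __init__(self, sampler, num_samples, statistics=None):
        self.sampler = sampler
        self.num_samples = num_samples
        # Users can pass any mapping of name -> StreamingStatistic.
        self._statistics = statistics if statistics is not None else {"mean": Mean()}

    def run(self, *args, **kwargs):
        for sample, chain_id in self.sampler.run(*args, **kwargs):
            for stat in self._statistics.values():
                stat.update(sample)
            del sample  # nothing is accumulated, so memory stays flat

    def summary(self, prob=0.9):
        return {name: stat.get() for name, stat in self._statistics.items()}
```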
What I'd really like is to be able to define custom statistics for a particular problem, e.g. saving a list of norms:
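For example (still a sketch), a user-defined statistic that records just the L2 norm of each sample for one site, which stays cheap even for huge latents:

```python
class NormHistory(StreamingStatistic):
    """Keeps one scalar norm per sample for a chosen site, instead of the sample itself."""

    def __init__(self, site_name):
        self.site_name = site_name
        self.norms = []

    def update(self, sample):
        self.norms.append(torch.linalg.norm(sample[self.site_name]).item())

    def get(self):
        return torch.tensor(self.norms)


# Hypothetical usage:
#   StreamingMCMC(sampler, num_samples=1000, statistics={"x_norms": NormHistory("x")})
```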
WDYT?
Addressing your earlier questions:
> Also in test_mcmc_api.py additional tests should run both types of MCMC classes and compare final statistics.
Correct, most existing tests should be parametrized with

```python
@pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC])
```
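For instance, an existing test might look roughly like this after parametrization (the toy model is just for illustration, and StreamingMCMC is the class proposed in this issue):

```python
import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


@pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC])
def test_streaming_matches_mcmc(mcmc_cls):
    def model():
        pyro.sample("x", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1))

    kernel = NUTS(model)
    mcmc = mcmc_cls(kernel, num_samples=200, warmup_steps=100)
    mcmc.run()
    # Both classes should produce the same summary statistics, up to numerical precision.
    mcmc.summary()
```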
> As described, StreamingMCMC shouldn't support multiprocessing manually via _Worker because here CUDA, which is thought to be the main beneficiary of this new class, handles vectorization by itself. (Is that correct?)
Almost. The main beneficiary here is large models that push against memory limits and therefore necessitate streaming rather than saving all samples in memory. And if you're pushing against memory limits, you'll want to avoid parallelizing and instead compute chains sequentially (which can itself be seen as a streaming operation). In practice, yes, most models that hit memory limits are run on GPU, but multicore CPU models can also be very performant.
> Should StreamingMCMC have a num_chains argument and for num_chains > 1 just compute them sequentially, or omit this argument?
StreamingMCMC should still support num_chains > 1 (which is valuable for determining convergence), but should compute the chains sequentially.
@mtsokol would you want to work on this in parallel? Maybe you could implement the StreamingMCMC class using hand-coded statistics, I could implement a basic pyro.ops.streaming module, and over the course of a few PRs we could meet in the middle?
@fritzo thanks for the guidance! Right now I'm looking at the current implementation and starting to work on this.
This abstraction with StreamingStatistic seems sound to me. StreamingMCMC will only iterate and call methods on the passed objects implementing that interface.
Sure! I can start working on StreamingMCMC and follow the StreamingStatistic notion from the start. When your PR is ready I will adjust my implementation.
Should I introduce some AbstractMCMC interface that the existing MCMC and StreamingMCMC will both implement?
Feel free to implement an AbstractMCMC interface if you like. I defer to your design judgement here.
@fritzo After thinking about handling those streamed samples I wanted to ask a few more questions:
So right now samples are yielded by the sampler and each one is appended to the right chain list by z_flat_acc[chain_id].append(x_cloned). Then we reshape to get rid of the last dimension and have dict entries in its place (based on the yielded structure). Then we perform an element-wise transform (with self.transforms), where the transform operation is determined by the dict entry.
A streaming-based approach would go as follows: again, each sample is yielded by the sampler. The sample is used to construct a dict (based on the yielded structure). Then that single dict is transformed (with self.transforms), and the sample is fed to each statistic via update(self, sample: Dict[str, torch.Tensor]). (So each single sample will result in constructing a new dict - is that OK?) A rough sketch of this flow is included after the questions below. WDYT?
Should StreamingStatistic.update be chain_id-aware? E.g. update(self, chain_id: int, sample: Dict[str, torch.Tensor]), so that it can compute chain-related diagnostics and support the group_by_chain argument?
Why do we need to clone, x_cloned = x.clone(), when num_chains > 1?
Follow-up on the first question: if such a thing makes a performance difference (I'm just wondering - it might be irrelevant), maybe instead of streaming each sample to the statistics it could work in batches. E.g. introduce an additional argument batch_size=100 so StreamingMCMC would wait until it aggregates 100 samples, then construct the dict, perform the transformations, and feed the whole batch to the statistics. (But maybe constructing a dict for each sample and transforming each sample separately isn't really an overhead - with a ready implementation I can run memory and time measurements.) WDYT?
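Here is the rough sketch of the per-sample flow mentioned in the first question (`unpack_to_dict` is a placeholder, not an existing Pyro function):

```python
def process_sample(x, transforms, statistics):
    # 1. Build a {site_name: tensor} dict from the flat sample yielded by the sampler.
    sample = unpack_to_dict(x)
    # 2. Apply the per-site transforms from self.transforms.
    sample = {k: transforms[k](v) if k in transforms else v for k, v in sample.items()}
    # 3. Feed the single transformed sample to every streaming statistic.
    for stat in statistics:
        stat.update(sample)
    # The dict can now be discarded; no list of samples is kept.
```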
@mtsokol answering your latest questions:

StreamingMCMC will typically be used with large memory-bound models with huge tensors, where the Python overhead is negligible. For this same reason I think we should avoid batching, since it increases memory overhead. (In fact I suspect the bottleneck will be in pyro.ops.streaming, where we may need to refactor to perform tensor operations in-place.)

I've added a .merge() operation in #2856 to make this easy for you. The main motivation is to compute cross-chain statistics like r_hat.
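For illustration only (a sketch, not the actual #2856 code), merging per-chain running means could look like this:

```python
def merge_means(count1, mean1, count2, mean2):
    """Pool running means from two chains, weighting by their sample counts."""
    total = count1 + count2
    pooled = {k: (count1 * mean1[k] + count2 * mean2[k]) / total for k in mean1}
    return total, pooled
```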
Hi @fritzo!
I was wondering what I can try to do next.
As "Add r_hat to pyro.ops.streaming" is completed, I tried n_eff = ess for streaming, but after a short inspection of the current implementation it looks infeasible to me (as it requires e.g. autocorrelations over lags, which would need the full sample history).
Apart from that, I can definitely try:

> Create a tutorial using StreamingMCMC on a big model

Could you suggest a problem and model that would be suitable for that?
Also, I could combine the new tutorial with your suggestion in the last bullet point of https://github.com/pyro-ppl/pyro/issues/2803#issuecomment-836644916 (showing how Predictive can be interchanged with poutine methods). WDYT?
This would be a documentation task and I was also looking for an implementation one. Have you got something that I can try?
This issue proposes a streaming architecture for MCMC on models with a large memory footprint.
The problem this addresses is that, in models with high-dimensional latents (say >1M latent variables), it becomes difficult to save a list of samples, especially on GPUs with limited memory. The proposed solution is to eagerly compute statistics on those samples, and discard them during inference.
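As a minimal illustration of the idea (a toy sketch, not Pyro code), a Welford-style update refines the mean and variance one sample at a time and never stores the samples:

```python
import torch

count, mean, m2 = 0, torch.zeros(1_000_000), torch.zeros(1_000_000)

for _ in range(100):
    x = torch.randn(1_000_000)  # stand-in for one high-dimensional MCMC sample
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)
    del x  # the sample is discarded immediately; memory use stays constant

variance = m2 / (count - 1)
```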
@fehiepsi suggested creating a new MCMC class (say StreamingMCMC) with a similar interface to MCMC and still independent of the kernel (using either HMC or NUTS), but that follows an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory constrained, it is reasonable to avoid multiprocessing support in StreamingMCMC.

Along with the new StreamingMCMC class I think there should be a set of helpers to streamingly compute statistics from sample streams, e.g. mean, variance, covariance, and r_hat statistics.

Tasks (to be split into multiple PRs)
@mtsokol
- Implement a StreamingMCMC class with an interface identical to MCMC (except disallowing parallel chains).
- Refactor existing tests of MCMC to parametrize over both MCMC and StreamingMCMC.
- Test that StreamingMCMC and MCMC perform identical computations, up to numerical precision.
- Create a tutorial using StreamingMCMC on a big model.

@fritzo
- Add r_hat to pyro.ops.streaming.
- Add n_eff = ess to pyro.ops.streaming.