pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[feature request] Parallelism support for sequential plate/guide-side enumeration #3219

Open amifalk opened 1 year ago

amifalk commented 1 year ago

For mixture models with arbitrary distributions over each feature, sampling currently must be done serially, even though these operations are trivially parallelizable.

To sample priors from a hierarchical mixture model with one continuous and one binary feature, you would need to do something like

import pyro
import pyro.distributions as dist

with pyro.plate('components', n_components):
   # pyro.plate used as a sequential iterator: features are sampled one at a time
   for i in pyro.plate('features', 2):
      if i == 0:
         pyro.sample('mu', dist.Normal(0, 1))
         pyro.sample('sigma_sq', dist.InverseGamma(1, 1))
      elif i == 1:
         pyro.sample('theta', dist.Beta(.1, .1))

For mixture models with a large number of features, this can become very slow.

I would love to be able to use a Joblib-like syntax for loops like these, i.e.

from joblib import Parallel, delayed

# each inner list is one feature's (name, distribution) pairs
features = [[('mu', dist.Normal(0, 1)), ('sigma_sq', dist.InverseGamma(1, 1))],
            [('theta', dist.Beta(.1, .1))]]

with pyro.plate('components', n_components):
   # sample_priors would be a helper that calls pyro.sample on each pair
   Parallel(n_jobs=-1)(delayed(sample_priors)(features[i]) for i in pyro.plate('features', 2))

I have tried something like this, but the Joblib backend and Pyro don't play nicely together: the model doesn't converge. (Presumably the worker processes don't share Pyro's process-local effect-handler stack, so sample statements executed in the workers never reach the enclosing trace.)
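
For what it's worth, here is a minimal sketch of what I suspect is happening (the helper, site names, and two-site model are just for illustration): sample statements executed inside Joblib workers are invisible to a trace recorded in the parent process.

import pyro
import pyro.distributions as dist
from joblib import Parallel, delayed

def sample_site(name):
    # runs in a worker process, outside the parent's effect-handler stack
    return pyro.sample(name, dist.Normal(0., 1.))

def model():
    Parallel(n_jobs=2)(delayed(sample_site)('x_%d' % i) for i in range(2))

trace = pyro.poutine.trace(model).get_trace()
# expect only '_INPUT' and '_RETURN'; 'x_0' and 'x_1' were never recorded
print(list(trace.nodes.keys()))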

In a similar vein, adding parallelism for sequential guide-side enumeration could also enable dramatic speedups. For example, when fitting CrossCat with SVI using two truncated stick-breaking processes over views and clusters (my personal use case), the view assignments cannot be enumerated out in the model, and enumerating them in the guide is much too slow when the replays cannot run simultaneously on multiple cores. Since each replay of the model shares no information with the others, this kind of parallelism seems possible in theory.
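
For reference, a minimal sketch of the sequential guide-side enumeration pattern in question (the toy model, guide, and parameter names are assumptions for illustration):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

def model():
    z = pyro.sample('z', dist.Categorical(torch.ones(3) / 3.))
    pyro.sample('x', dist.Normal(z.float(), 1.), obs=torch.tensor(0.5))

def guide():
    probs = pyro.param('probs', torch.ones(3) / 3.,
                       constraint=dist.constraints.simplex)
    # 'sequential' replays the guide and model once per value of z,
    # one replay at a time: this is the serial loop at issue
    pyro.sample('z', dist.Categorical(probs), infer={'enumerate': 'sequential'})

svi = SVI(model, guide, Adam({'lr': 0.01}), loss=TraceEnum_ELBO())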

I realize this may be difficult for reasons mentioned in #2354, but is any parallelism like this possible in Pyro?

fritzo commented 1 year ago

Hmm, I'd guess the most straightforward approach to inter-distribution CPU parallelism would be to rely on the PyTorch jit, by simply using JitTrace_ELBO or a similar jitted ELBO.

Pros:

Cons:
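
A minimal sketch of this suggestion (the toy model and autoguide are illustrative assumptions, not from the thread):

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

def model():
    with pyro.plate('components', 3):
        pyro.sample('mu', dist.Normal(0., 1.))
        pyro.sample('sigma_sq', dist.InverseGamma(1., 1.))

guide = AutoNormal(model)
# JitTrace_ELBO compiles the ELBO computation on the first call, after
# which the fused graph can exploit PyTorch's intra-op parallelism
svi = SVI(model, guide, Adam({'lr': 0.01}), loss=JitTrace_ELBO())
for step in range(100):
    svi.step()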

pavleb commented 10 months ago

@amifalk Did you make any progress in this area? I'm facing the same issue when dealing with model selection from a set of models with significantly different structures. I have a partial solution: use poutine.mask to mask out, in the model and guide traces, the log-likelihood contributions of the models that are not currently selected by the discrete enumeration. This way parallel enumeration can be used.
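
For concreteness, a minimal sketch of that masking pattern (the two candidate models here are illustrative assumptions):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(data):
    # discrete model index, enumerated in parallel
    m = pyro.sample('m', dist.Categorical(torch.ones(2) / 2.),
                    infer={'enumerate': 'parallel'})
    # only the selected model's log-likelihood contributes to the ELBO
    with poutine.mask(mask=(m == 0)):
        mu = pyro.sample('mu', dist.Normal(0., 1.))
        pyro.sample('obs_0', dist.Normal(mu, 1.), obs=data)
    with poutine.mask(mask=(m == 1)):
        rate = pyro.sample('rate', dist.Gamma(1., 1.))
        pyro.sample('obs_1', dist.Exponential(rate), obs=data)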

However, for complicated model structures and a large set of models, the masking becomes quite involved and prone to mistakes that cannot be easily debugged.

amifalk commented 10 months ago

Sorry, no updates currently @pavleb. We ended up resolving our speed issues by moving over to numpyro.
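
For anyone landing here, a rough sketch of what the NumPyro port of the original example looks like (an assumption about the approach, not code from the thread); under JAX the model is traced and jit-compiled once, so the per-site Python overhead is paid only at trace time:

import numpyro
import numpyro.distributions as dist

def model(n_components):
    with numpyro.plate('components', n_components):
        numpyro.sample('mu', dist.Normal(0., 1.))
        numpyro.sample('sigma_sq', dist.InverseGamma(1., 1.))
        numpyro.sample('theta', dist.Beta(0.1, 0.1))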