pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

[FR] Utility for joint distributions #1810

Closed Qazalbash closed 3 months ago

Qazalbash commented 3 months ago

Description

I'm requesting a utility that allows creating a joint distribution from independent distributions, like this:

$$p_{X_1,X_2,\cdots,X_n}(x_1,x_2,\cdots,x_n)=p_{X_1}(x_1)\,p_{X_2}(x_2)\cdots p_{X_n}(x_n)$$

Most importantly, we should be able to use the log_prob method alongside sample.

My implementation

```python
# Copyright 2023 The GWKokab Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax import lax, numpy as jnp
from jaxtyping import PRNGKeyArray

from numpyro import distributions as dist
from numpyro.util import is_prng_key


class JointDistribution(dist.Distribution):
    r"""Joint distribution of multiple marginal distributions."""

    def __init__(self, *marginal_distributions: dist.Distribution) -> None:
        r"""
        :param marginal_distributions: A sequence of marginal distributions.
        """
        self.marginal_distributions = marginal_distributions
        self.shaped_values = tuple()
        batch_shape = lax.broadcast_shapes(
            *tuple(d.batch_shape for d in self.marginal_distributions)
        )
        k = 0
        for d in self.marginal_distributions:
            if d.event_shape:
                self.shaped_values += (slice(k, k + d.event_shape[0]),)
                k += d.event_shape[0]
            else:
                self.shaped_values += (k,)
                k += 1
        super(JointDistribution, self).__init__(
            batch_shape=batch_shape,
            event_shape=(k,),
            validate_args=True,
        )

    def log_prob(self, value):
        log_probs = jax.tree.map(
            lambda d, v: d.log_prob(value[..., v]),
            self.marginal_distributions,
            self.shaped_values,
            is_leaf=lambda x: isinstance(x, dist.Distribution),
        )
        log_probs = jnp.sum(jnp.asarray(log_probs).T, axis=-1)
        return log_probs

    def sample(self, key: PRNGKeyArray, sample_shape: tuple[int, ...] = ()):
        assert is_prng_key(key)
        keys = tuple(jax.random.split(key, len(self.marginal_distributions)))
        samples = jax.tree.map(
            lambda d, k: d.sample(k, sample_shape).reshape(*sample_shape, -1),
            self.marginal_distributions,
            keys,
            is_leaf=lambda x: isinstance(x, dist.Distribution),
        )
        samples = jnp.concatenate(samples, axis=-1)
        return samples
```
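
A minimal usage sketch of the class above (the Normal and Gamma marginals here are placeholder choices, not part of the original code):

```python
# Usage sketch for the JointDistribution class above; the concrete marginals
# are placeholders chosen only for illustration.
import jax
from numpyro import distributions as dist

joint = JointDistribution(dist.Normal(0.0, 1.0), dist.Gamma(2.0, 2.0))

key = jax.random.PRNGKey(0)
samples = joint.sample(key, sample_shape=(5,))  # shape (5, 2): one column per marginal
log_probs = joint.log_prob(samples)             # sum of the marginal log-densities, shape (5,)
```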

fehiepsi commented 3 months ago

Hi @Qazalbash, you might want to use tfd.JointDistribution. It will likely work under numpyro inference methods.
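
For example, something along these lines (a sketch assuming TFP's JAX substrate is installed; the marginals are illustrative):

```python
# Sketch of the suggestion above, assuming tensorflow_probability's JAX substrate
# is available; JointDistributionSequential is one of several JointDistribution flavors.
import jax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
joint = tfd.JointDistributionSequential(
    [tfd.Normal(0.0, 1.0), tfd.Gamma(2.0, 2.0)]
)

samples = joint.sample(seed=jax.random.PRNGKey(0))  # one array per component
log_prob = joint.log_prob(samples)                  # sum of the component log-densities
```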

Qazalbash commented 3 months ago

@fehiepsi, I'm working in a pure JAX environment and want to avoid mixing tensors with JAX arrays. Any suggestions on how to achieve this?

fehiepsi commented 3 months ago

I guess it's fairly complicated. You'll need to define a joint support and the corresponding transform properly. Then I guess your implementation could work with numpyro inference methods.

Qazalbash commented 3 months ago

I completely understand your reservations, but I truly believe this utility would be highly valuable! I'd be more than happy to help in any way I can.

tillahoffmann commented 3 months ago

Just out of interest, what's the advantage of having a factorized joint distribution over using multiple sample statements and then stacking the samples?
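
i.e., roughly this, with hypothetical marginals:

```python
# Rough sketch of the "multiple sample statements, then stack" alternative;
# the Normal/Gamma marginals are hypothetical.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model():
    x1 = numpyro.sample("x1", dist.Normal(0.0, 1.0))
    x2 = numpyro.sample("x2", dist.Gamma(2.0, 2.0))
    # Stack the independent draws if a single joint vector is needed downstream.
    return jnp.stack([x1, x2], axis=-1)
```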

fehiepsi commented 3 months ago

I'm also curious. I guess it is useful for a sort of autoguide style, where neural networks will approximate the joint posterior. Even in that case, it's still better to explicitly factorize the variational distribution into separate ones.

Qazalbash commented 3 months ago

I've been working on a project that involves sampling from and calculating the log_prob of joint distributions (see the code above). My goal is to create a utility class that streamlines this process, avoiding the need for manual loops and stacking, and boosts performance by staying purely in JAX. This is my approach, but I'm open to feedback and corrections - perhaps I'm mistaken!

tillahoffmann commented 3 months ago

My gut feeling is that there likely won't be a large performance improvement unless the list of marginal distributions is very large, e.g., because there is still a loop to infer the shapes? Have you already run benchmarks for the two alternatives?

If the list is very large, it might make sense to combine distributions that have the same class, e.g., batch all the Normal distributions in marginal_distributions, batch all the Gamma distributions, etc.?
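
As a sketch of that idea (the parameters here are hypothetical):

```python
# Sketch of combining same-family marginals into one batched distribution so
# log_prob becomes a single vectorized call per family; parameters are hypothetical.
import jax.numpy as jnp
import numpyro.distributions as dist

normals = [dist.Normal(0.0, 1.0), dist.Normal(2.0, 0.5), dist.Normal(-1.0, 3.0)]
batched = dist.Normal(
    loc=jnp.stack([d.loc for d in normals]),
    scale=jnp.stack([d.scale for d in normals]),
)

value = jnp.array([0.1, 1.9, -0.5])
log_prob = batched.log_prob(value).sum(-1)  # equals the sum of the three marginal log-densities
```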

Qazalbash commented 3 months ago

I haven't benchmarked it yet! I will try it as soon as possible.

danielward27 commented 3 months ago

Perhaps I'm missing something, but the annoyance I have with relying on explicitly factorising the guide is that 1) conditioning becomes a lot of hassle, 2) it presumably forces computation to be performed sequentially (compared to autoregressive models that e.g. use masking). To me, it would be great if I could easily specify a joint distribution in the guide corresponding to a set of latents from the model.

Is there any way to "hack" this e.g. using factor/Delta statements?

tillahoffmann commented 3 months ago

Do you have an example of what this might look like in the guide, e.g., as pseudocode?

danielward27 commented 3 months ago

Edit: I think what I want is already supported - are there any issues with this sort of approach?

```python
import jax.numpy as jnp
import numpyro
from numpyro import sample, param
from numpyro.distributions import Delta, MultivariateNormal, Normal, constraints


def model(obs=None):
    a = sample("a", Normal())
    b = sample("b", Normal())
    sample("c", Normal(a + b), obs=obs)


def guide(*args, **kwargs):
    # Joint (correlated) variational distribution over the model latents a and b.
    loc = param("loc", jnp.zeros(2))
    cholesky = param("cholesky", jnp.eye(2), constraint=constraints.lower_cholesky)
    joint_dist = MultivariateNormal(loc, cholesky @ cholesky.T)
    joint_sample = sample("joint", joint_dist)

    # Expose the components under the model's site names.
    a, b = joint_sample
    sample("a", Delta(a))
    sample("b", Delta(b))
```

tillahoffmann commented 3 months ago

Yes, your example implementation is actually very close to the automatic guides here.

https://github.com/pyro-ppl/numpyro/blob/5af9ebda72bd7aeb08c61e4248ecd0d982473224/numpyro/infer/autoguide.py#L719-L722

You can slightly shorten your code, if you like, by using a deterministic site like so.

```python
a, b = joint_sample
numpyro.deterministic("a", a)
numpyro.deterministic("b", b)
```

Doesn't save you many keystrokes, but it may be more explicit?

fehiepsi commented 3 months ago

It would be nice to include

```python
joint_sample = sample("joint", joint_dist, infer={"is_auxiliary": True})
```

Some objectives explicitly check that sample sites in the guide are also present in the model.
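
Putting the pieces together, the guide might then look like this (a sketch combining the earlier example with the auxiliary flag and the deterministic-site variant):

```python
# Sketch only: the earlier guide with the auxiliary flag added. "joint" has no
# counterpart in the model, so it is marked as auxiliary; the deterministic
# sites expose "a" and "b" under the model's site names.
import jax.numpy as jnp
import numpyro
from numpyro import sample, param
from numpyro.distributions import MultivariateNormal, constraints


def guide(*args, **kwargs):
    loc = param("loc", jnp.zeros(2))
    cholesky = param("cholesky", jnp.eye(2), constraint=constraints.lower_cholesky)
    joint_dist = MultivariateNormal(loc, cholesky @ cholesky.T)
    joint_sample = sample("joint", joint_dist, infer={"is_auxiliary": True})

    a, b = joint_sample
    numpyro.deterministic("a", a)
    numpyro.deterministic("b", b)
```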

Closed because the issue seems to have been resolved.