Closed Qazalbash closed 3 months ago
Hi @Qazalbash, you might want to use tfd.JointDistribution. It will likely work under numpyro inference methods.
@fehiepsi, I'm working in a pure JAX environment and want to avoid mixing tensors with JAX arrays. Any suggestions on how to achieve this?
I guess it's fairly complicated. You'll need to define a joint support and the corresponding transform properly. Then I guess your implementation could work with numpyro inference methods.
I completely understand your reservations, but I truly believe this utility would be highly valuable! I'd be more than happy to support in any way I can.
Just out of interest, what's the advantage of having a factorized joint distribution over using multiple sample statements and then stacking the samples?
I'm also curious. I guess it is useful for a sort of autoguide style, where neural networks will approximate the joint posterior. Even in that case, it's still better to explicitly factorize the variational distribution into separate ones.
I've been working on a project that involves sampling from and calculating the `log_prob` of joint distributions (see code above). My goal is to create a utility class that streamlines this process, avoids the need for manual loops and stacking, and boosts performance by staying purely in JAX. This is my approach, but I'm open to feedback and corrections - perhaps I'm mistaken!
My gut feeling is that there likely won't be a large performance improvement unless the list of marginal distributions is very large, e.g., because there is still a loop to infer the shapes? Have you already run benchmarks for the two alternatives?
If the list is very large, it might make sense to combine distributions that have the same class, e.g., batch all the `Normal` distributions in `marginal_distributions`, batch all the `Gamma` distributions, etc.?
I haven't benchmarked it yet! I will try it as soon as possible.
Perhaps I'm missing something, but the annoyance I have with relying on explicitly factorising the guide is that 1) conditioning becomes a lot of hassle, 2) it presumably forces computation to be performed sequentially (compared to autoregressive models that e.g. use masking). To me, it would be great if I could easily specify a joint distribution in the guide corresponding to a set of latents from the model.
Is there any way to "hack" this e.g. using factor/Delta statements?
Do you have an example of what this might look like in the guide, e.g., as pseudocode?
Edit: I think what I want is already supported - are there any issues with this sort of approach?
import jax.numpy as jnp
import jax.random as jr
from numpyro import sample, param, factor
from numpyro.distributions import Normal, Delta, constraints, MultivariateNormal
import numpyro

def model(obs=None):
    a = sample("a", Normal())
    b = sample("b", Normal())
    sample("c", Normal(a + b), obs=obs)

def guide(*args, **kwargs):
    loc = param("loc", jnp.zeros(2))
    cholesky = param("cholesky", jnp.eye(2), constraint=constraints.lower_cholesky)
    joint_dist = MultivariateNormal(loc, cholesky @ cholesky.T)
    joint_sample = sample("joint", joint_dist)
    a, b = joint_sample
    sample("a", Delta(a))
    sample("b", Delta(b))
Yes, your example implementation is actually very close to the automatic guides here.
You can slightly shorten your code, if you like, by using a `deterministic` site like so.
a, b = joint_sample
numpyro.deterministic("a", a)
numpyro.deterministic("b", b)
Doesn't save you many keystrokes, but it may be more explicit?
It would be nice to include
joint_sample = sample("joint", joint_dist, infer={'is_auxiliary': True})
Some objectives have an explicit check that sample sites in guide need to be available in the model.
Closed because the issue seems to have been resolved.
Description
I'd like to request a utility that creates a joint distribution from independent distributions, like this:
$$p_{X_1,X_2,\cdots,X_n}(x_1,x_2,\cdots,x_n)=p_{X_1}(x_1)\,p_{X_2}(x_2)\cdots p_{X_n}(x_n)$$
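Taking logarithms turns this product into a sum over the marginals, which is the quantity a joint `log_prob` would return:

$$\log p_{X_1,X_2,\cdots,X_n}(x_1,x_2,\cdots,x_n)=\sum_{i=1}^{n}\log p_{X_i}(x_i)$$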
Most importantly, we should be able to use the `log_prob` method along with `sample`.

My implementation
```python
# Copyright 2023 The GWKokab Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax import lax, numpy as jnp
from jaxtyping import PRNGKeyArray
from numpyro import distributions as dist
from numpyro.util import is_prng_key


class JointDistribution(dist.Distribution):
    r"""Joint distribution of multiple marginal distributions."""

    def __init__(self, *marginal_distributions: dist.Distribution) -> None:
        r"""
        :param marginal_distributions: A sequence of marginal distributions.
        """
        self.marginal_distributions = marginal_distributions
        self.shaped_values = tuple()
        batch_shape = lax.broadcast_shapes(
            *tuple(d.batch_shape for d in self.marginal_distributions)
        )
        k = 0
        for d in self.marginal_distributions:
            if d.event_shape:
                self.shaped_values += (slice(k, k + d.event_shape[0]),)
                k += d.event_shape[0]
            else:
                self.shaped_values += (k,)
                k += 1
        super(JointDistribution, self).__init__(
            batch_shape=batch_shape,
            event_shape=(k,),
            validate_args=True,
        )

    def log_prob(self, value):
        log_probs = jax.tree.map(
            lambda d, v: d.log_prob(value[..., v]),
            self.marginal_distributions,
            self.shaped_values,
            is_leaf=lambda x: isinstance(x, dist.Distribution),
        )
        log_probs = jnp.sum(jnp.asarray(log_probs).T, axis=-1)
        return log_probs

    def sample(self, key: PRNGKeyArray, sample_shape: tuple[int, ...] = ()):
        assert is_prng_key(key)
        keys = tuple(jax.random.split(key, len(self.marginal_distributions)))
        samples = jax.tree.map(
            lambda d, k: d.sample(k, sample_shape).reshape(*sample_shape, -1),
            self.marginal_distributions,
            keys,
            is_leaf=lambda x: isinstance(x, dist.Distribution),
        )
        samples = jnp.concatenate(samples, axis=-1)
        return samples
```