pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

[FR] Support for different supports in component distributions for mixture models #1780

Closed: Qazalbash closed this issue 4 months ago

Qazalbash commented 5 months ago

Description

I am submitting a feature request to enhance NumPyro's handling of mixture models whose component distributions have different supports. Currently, when constructing a mixture model with numpyro.distributions.MixtureGeneral, all component distributions are required to have the same support.

In my specific use case, I am implementing a mixture model in which one component is a truncated power law with a high-mass cutoff represented by a Heaviside step function, and the other component is a Gaussian.

The formulation of the model is as follows:

$$ p(m_1\mid\theta) \propto (1-\lambda)\,A(\theta)\,m_1^{-\alpha}\,\Theta(m_{\text{max}}-m_1) + \lambda\,B(\theta)\,\exp\left(-\frac{(m_1-\mu_m)^2}{2\sigma_m^2}\right) $$

where $\theta$ represents the parameter vector $\left(\lambda,\alpha,m_{\text{min}},\delta_m,m_{\text{max}},\mu_m,\sigma_m\right)$, and $A(\cdot)$ and $B(\cdot)$ are normalization factors.

To implement this model efficiently, it is essential to have the flexibility to use component distributions with varying supports within the same mixture model.

Therefore, I kindly request this feature so that users can construct mixture models whose component distributions have different supports. If there is an existing workaround that does not require a new feature, that would also be great.
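For concreteness, here is a minimal sketch of how this model would naturally be expressed (not working code: constructing the mixture currently raises the same-support ValueError discussed below). A TruncatedNormal stands in for the truncated power-law component, and all priors are placeholders; the point is only that the two components have different supports:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(m1=None):
    lam = numpyro.sample("lambda", dist.Uniform(0.0, 1.0))
    m_max = numpyro.sample("m_max", dist.Uniform(30.0, 100.0))
    mu_m = numpyro.sample("mu_m", dist.Uniform(10.0, 50.0))
    sigma_m = numpyro.sample("sigma_m", dist.Uniform(1.0, 10.0))

    # Component 0 has support (-inf, m_max] (the Heaviside cutoff);
    # component 1 has support on the whole real line, so MixtureGeneral
    # currently rejects this combination.
    mixture = dist.MixtureGeneral(
        dist.Categorical(probs=jnp.stack([1.0 - lam, lam])),
        [
            dist.TruncatedNormal(loc=mu_m, scale=sigma_m, high=m_max),
            dist.Normal(mu_m, sigma_m),
        ],
    )
    numpyro.sample("m1", mixture, obs=m1)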

fehiepsi commented 5 months ago

Hi @Qazalbash, you can follow the examples in the Discrete Latent Variables section, e.g. the Gaussian Mixture Model, where we use enumeration for mixture models.
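For reference, here is a minimal enumeration sketch in the spirit of that example (a two-component Gaussian mixture; the discrete assignment z is enumerated out exactly, so the sampler only ever sees continuous latents):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def gmm(data):
    weights = numpyro.sample("weights", dist.Dirichlet(jnp.ones(2)))
    locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
    with numpyro.plate("data", data.shape[0]):
        # Marked for parallel enumeration, so z is marginalized out exactly.
        z = numpyro.sample("z", dist.Categorical(weights),
                           infer={"enumerate": "parallel"})
        numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)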

zmbc commented 5 months ago

@fehiepsi Apologies if I'm missing something obvious, but is there a class of models where the enumeration trick can't be used?

As a toy example, let's say our data generation process looks like:

import numpy as np

def data_generation_process(mu):
    # Draw Gaussian samples, then repeatedly shift every value above 0.4
    # down by 0.1, piling mass up just below the threshold.
    x = np.random.normal(loc=mu, scale=1, size=10_000)

    for _ in range(5):
        x = np.where(x > 0.4, x - 0.1, x)

    return x

and then we want to infer the mu parameter from data. I tried this using NumPyro:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def difficult_without_mixtures(y=None):
    x = dist.Normal(loc=numpyro.param("mu", 0.0), scale=1.0)

    for _ in range(5):
        # Probability mass above the threshold, i.e. eligible to be shifted.
        eligible_for_shift = 1 - x.cdf(0.4)
        x_eligible = dist.TruncatedDistribution(x, low=0.4)
        x_ineligible = dist.TruncatedDistribution(x, high=0.4)
        # The shifted component has support (0.3, inf); the unshifted one
        # has support (-inf, 0.4). Mixture rejects this combination.
        x = dist.Mixture(
            dist.Categorical(probs=jnp.stack([eligible_for_shift, 1 - eligible_for_shift])),
            [
                dist.TransformedDistribution(
                    x_eligible, dist.transforms.AffineTransform(loc=-0.1, scale=1.0)
                ),
                x_ineligible,
            ],
        )

    numpyro.sample("obs", x, obs=y)

but got the error: ValueError: All component distributions must have the same support. Yet I see no inherent reason why this couldn't work in principle; PyMC, for example, allows mixtures whose components have different supports. I could implement a custom distribution class (like this), but at that point I'd be writing essentially the same code NumPyro would need to add this feature anyway.
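To sketch what that custom class might look like (all names here are illustrative): a hand-rolled mixture whose log_prob combines the components with logsumexp and masks each one outside its own support, skipping the check that dist.Mixture performs.

import jax.numpy as jnp
from jax.scipy.special import logsumexp
import numpyro.distributions as dist

class HeterogeneousMixture(dist.Distribution):
    # Overall support of the mixture; here simply the whole real line.
    support = dist.constraints.real

    def __init__(self, mixing_probs, components):
        self.mixing_probs = mixing_probs
        self.components = components
        super().__init__(batch_shape=(), event_shape=())

    def log_prob(self, value):
        # log p(x) = logsumexp_k [log w_k + log p_k(x)], masking each
        # component to -inf outside its own support.
        lps = jnp.stack(
            [jnp.where(c.support(value), c.log_prob(value), -jnp.inf)
             for c in self.components],
            axis=-1,
        )
        return logsumexp(jnp.log(self.mixing_probs) + lps, axis=-1)

    # sample() is omitted; log_prob is all that an observed site needs.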

fehiepsi commented 5 months ago

I think we can allow MixtureGeneral to take a support argument and skip the same-support check when one is provided. In that case, we need to enable validate_args for each component so that its log_prob will be -inf for out-of-domain samples. Do you want to make a PR for it?
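The proposed usage would then be something along these lines (argument name and semantics as sketched above, not a final API):

import jax.numpy as jnp
import numpyro.distributions as dist

# Components with different supports: (0.4, inf) vs. the whole real line.
mix = dist.MixtureGeneral(
    dist.Categorical(probs=jnp.array([0.3, 0.7])),
    [
        dist.TruncatedNormal(0.0, 1.0, low=0.4),
        dist.Normal(0.0, 1.0),
    ],
    support=dist.constraints.real,  # proposed: declare the support, skip the check
)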

Qazalbash commented 5 months ago

@fehiepsi I can try!

zmbc commented 4 months ago

Thank you both so much! 🎉