pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Grads w.r.t. weights of `MixtureGeneral` Distribution are giving `nan`s #1870

Open Qazalbash opened 2 days ago

Qazalbash commented 2 days ago

Hi,

We have created some models in which we estimate the weights of the `MixtureGeneral` distribution. However, when computing the gradient with respect to this argument, we are encountering `nan` values. We enabled `jax.config.update("jax_debug_nans", True)` to diagnose the issue, and it pointed to the following line:

https://github.com/pyro-ppl/numpyro/blob/8e9313fd64a34162bc1c08b20ed310373e82e347/numpyro/distributions/mixtures.py#L152

I suspect that after the implementation of https://github.com/pyro-ppl/numpyro/pull/1791, extra care is needed to handle `inf` and `nan` values, possibly by using a double `where` for a safe `logsumexp`.
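
For illustration, this is the kind of double-`where` guard I have in mind (a hypothetical helper with placeholder names, not a proposed patch):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def safe_logsumexp(x, axis=-1):
    # Slices where every entry is -inf make logsumexp's gradient nan (0/0).
    all_neg_inf = jnp.all(jnp.isneginf(x), axis=axis, keepdims=True)
    # First where: feed the reduction a harmless dummy value on those slices,
    # so neither the forward pass nor its gradient evaluates (-inf) - (-inf).
    x_safe = jnp.where(all_neg_inf, 0.0, x)
    out = logsumexp(x_safe, axis=axis, keepdims=True)
    # Second where: restore the mathematically correct -inf result.
    out = jnp.where(all_neg_inf, -jnp.inf, out)
    return jnp.squeeze(out, axis=axis)
```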

> [!IMPORTANT]
> This is an urgent issue, so a prompt response would be greatly appreciated.

fehiepsi commented 2 days ago

You can add `jax.debug.print(...)` to inspect the component log probs. If all of the component log probs are `-inf`, `nan` will happen.
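
For example, something along these lines inside the traced computation (assuming the mixture's `component_log_probs` helper; adapt the names to your model):

```python
import jax

# Illustrative only: print the per-component log probabilities to see
# whether any (or all) of them are -inf for some data points.
component_lp = mixture.component_log_probs(y)
jax.debug.print("component log probs: {lp}", lp=component_lp)
```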

Qazalbash commented 2 days ago

So I checked: there are no `nan`s in the component log probs, but some (not all) of them are `-inf`.

Let me explain what I am trying to do. Suppose we have a mixture model with $p_i(x|\Lambda)$ as the probability density of each component. We also have $R_i$ as the scaling factor of each component; these are usually greater than 1. To make them work with `MixtureGeneral`, we normalize them and pass the result to `CategoricalProbs` as the mixing distribution.

So the computation happens like this:

$$ \log(p(x|\Lambda))=\log\left(\sum_{i=1}^{n}R_i\,p_i(x|\Lambda)\right) \iff \log(p(x|\Lambda))=\log\left(\sum_{j=1}^{n}R_j\right)+\log\left(\sum_{i=1}^{n}\frac{R_i}{\sum_{j=1}^{n}R_j}\,p_i(x|\Lambda)\right) $$

We are trying to estimate $\log(R_i)$ along with some other parameters. We do it like this:

```python
import jax
from numpyro.distributions import CategoricalProbs

model = ...      # pass parameters to initialize the MixtureGeneral object
log_rates = ...  # array of log rates, log(R_i)
log_sum_of_rates = jax.nn.logsumexp(log_rates, axis=-1)  # log(R_0 + R_1 + ... + R_n)
mixing_probs = jax.nn.softmax(log_rates, axis=-1)  # normalized rates [R_0 / sum(R_j), ..., R_n / sum(R_j)]
model._mixing_distribution = CategoricalProbs(probs=mixing_probs)
log_p = model.log_prob(y)  # log probability under the normalized mixture
log_p = log_p + log_sum_of_rates  # add back the normalization we divided out
```
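
For reference, a self-contained toy version of the above (the `Normal` components and the data here are placeholders, not our actual model), showing how we take the gradient with respect to the log rates:

```python
import jax
import jax.numpy as jnp
from numpyro.distributions import CategoricalProbs, MixtureGeneral, Normal


def neg_log_likelihood(log_rates, y):
    # Placeholder components; in our real model these are the p_i(x | Lambda).
    components = [Normal(-1.0, 1.0), Normal(1.0, 1.0)]
    mixing_probs = jax.nn.softmax(log_rates, axis=-1)
    mixture = MixtureGeneral(CategoricalProbs(probs=mixing_probs), components)
    log_sum_of_rates = jax.nn.logsumexp(log_rates, axis=-1)
    return -jnp.sum(mixture.log_prob(y) + log_sum_of_rates)


y = jnp.array([0.3, -0.7, 2.1])
grads = jax.grad(neg_log_likelihood)(jnp.zeros(2), y)
# with our real components, these gradients come out as nan
```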