Open · Qazalbash opened this issue 2 days ago
Hi,

We have created some models where we estimate the weights of the `MixtureGeneral` distribution. However, when computing the gradient with respect to this argument, we are encountering `nan` values. We enabled `jax.config.update("jax_debug_nans", True)` to diagnose the issue, and it pointed to the following line:
https://github.com/pyro-ppl/numpyro/blob/8e9313fd64a34162bc1c08b20ed310373e82e347/numpyro/distributions/mixtures.py#L152

I suspect that after the implementation of https://github.com/pyro-ppl/numpyro/pull/1791, extra care is needed to handle `inf` and `nan` values, possibly by using a double `where` for a safe `logsumexp`.
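By a double `where` I mean something along these lines (just a sketch in plain `jax.numpy`, not NumPyro's actual implementation; `safe_logsumexp` is a name I made up):

```python
import jax
import jax.numpy as jnp

def safe_logsumexp(x, axis=-1):
    # Double-where trick: mask slices that are entirely -inf before the
    # reduction, so the gradient never evaluates -inf - (-inf) = nan,
    # then restore -inf in the output for exactly those slices.
    all_neg_inf = jnp.all(jnp.isneginf(x), axis=axis, keepdims=True)
    safe_x = jnp.where(all_neg_inf, 0.0, x)  # first where: make the input harmless
    out = jax.nn.logsumexp(safe_x, axis=axis)
    return jnp.where(jnp.squeeze(all_neg_inf, axis=axis), -jnp.inf, out)  # second where
```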
---

You can add `jax.debug.print(...)` to inspect the component log probs. If all of the component log probs are `-inf`, `nan` will happen.
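For instance (a self-contained sketch with made-up components; in your case you would inspect the log probs of your actual model):

```python
import jax
import jax.numpy as jnp
from numpyro.distributions import CategoricalProbs, MixtureGeneral, Normal

mixture = MixtureGeneral(
    CategoricalProbs(probs=jnp.array([0.5, 0.5])),
    [Normal(0.0, 1.0), Normal(3.0, 2.0)],  # made-up components
)
value = jnp.array([0.5, 2.0])

# Stack the per-component log probs (mirroring what the mixture's log_prob
# reduces over) and print them at trace time; jax.debug.print also works
# under jit/grad, unlike a plain Python print.
component_log_probs = jnp.stack(
    [d.log_prob(value) for d in mixture.component_distributions], axis=-1
)
jax.debug.print("component log probs: {}", component_log_probs)
```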
---

So I checked, and I got no `nan`s but some `inf`s, not all of them.
Let me explain what I am trying to do. Suppose we have a mixture model with $p_i(x|\Lambda)$ as the probability density of each component. We also have $R_i$ as the scaling factor of each component; these are usually greater than 1. To make them work with `MixtureGeneral`, we normalize them and pass the result to `CategoricalProbs` as the mixing distribution.
So the computation happens like this:
$$\log(p(x|\Lambda))=\log\left(\sum_{i=1}^{n}R_i\,p_i(x|\Lambda)\right)=\log\left(\sum_{j=1}^{n}R_j\right)+\log\left(\sum_{i=1}^{n}\frac{R_i}{\sum_{j=1}^{n}R_j}\,p_i(x|\Lambda)\right)$$
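As a quick numeric sanity check of this identity (with made-up numbers):

```python
import jax.numpy as jnp

R = jnp.array([2.0, 5.0])   # made-up rates R_i
p = jnp.array([0.3, 0.1])   # made-up component densities p_i(x | Lambda)

lhs = jnp.log(jnp.sum(R * p))                                      # direct form
rhs = jnp.log(jnp.sum(R)) + jnp.log(jnp.sum(R / jnp.sum(R) * p))   # decomposed form
print(lhs, rhs)  # identical up to floating point error
```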
We are trying to estimate the $\log(R_i)$ along with some other parameters. We do it like this:
```python
model = ...  # pass parameters to initialize the MixtureGeneral object
log_rates = ...  # array of log rates
log_sum_of_rates = jax.nn.logsumexp(log_rates, axis=-1)  # log(R_0 + R_1 + ... + R_n)
mixing_probs = jax.nn.softmax(log_rates, axis=-1)  # normalized rates [R_0 / sum(R_j), ..., R_n / sum(R_j)]
model._mixing_distribution = CategoricalProbs(probs=mixing_probs)
log_p = model.log_prob(y)  # log probability under the normalized mixture
log_p = log_p + log_sum_of_rates  # add back the log normalizer we divided out
```
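For completeness, here is a self-contained version of the same pattern (component distributions and observations are made up; with such well-behaved components the gradient is finite, but in our real models some component log probs are `-inf` and the gradient of exactly this computation comes out as `nan`):

```python
import jax
import jax.numpy as jnp
from numpyro.distributions import CategoricalProbs, MixtureGeneral, Normal

def total_log_prob(log_rates, y):
    # Rebuild the mixture from the current log rates so that everything
    # is differentiable with respect to log_rates.
    log_sum_of_rates = jax.nn.logsumexp(log_rates, axis=-1)
    mixing_probs = jax.nn.softmax(log_rates, axis=-1)
    model = MixtureGeneral(
        CategoricalProbs(probs=mixing_probs),
        [Normal(0.0, 1.0), Normal(3.0, 2.0)],  # made-up components
    )
    return jnp.sum(model.log_prob(y) + log_sum_of_rates)

y = jnp.array([0.5, 2.0, -1.0])             # made-up observations
log_rates = jnp.log(jnp.array([2.0, 5.0]))
grads = jax.grad(total_log_prob)(log_rates, y)
```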