tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Mixture model (Zero-inflated) gives nan log-probability at 0 #1224

Open davidlkl opened 3 years ago

davidlkl commented 3 years ago

To reproduce the result:

import tensorflow_probability as tfp
tfd = tfp.distributions

# Zero-inflated log-normal: P(X = 0) = 0.9, otherwise X ~ LogNormal(6, 1)
ziln = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.9, 0.1]),
    components=[
        tfd.Deterministic(loc=0.),
        tfd.LogNormal(loc=6., scale=1.),
    ])
print(ziln.log_prob(0.))
# Expected output: log(0.9)
# Output: nan
print(ziln.log_prob(1.))
# Expected output: -21.221523
# Output: -21.221523

I think log_prob should be log(0.9) if x == 0, and log(0.1 * lognormal.pdf(x)) if x > 0.
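For reference, the piecewise log_prob proposed above can be written out in plain Python. This is a minimal sketch: `ziln_log_prob` is a hypothetical helper (not a TFP API), with the parameters from the reproduction hard-coded as defaults. Note that the value returned at 0 is the log of a probability mass, while the value for x > 0 is the log of a density.

```python
import math

def ziln_log_prob(x, p_zero=0.9, mu=6.0, sigma=1.0):
    """Hypothetical helper: piecewise log-probability of a zero-inflated log-normal.

    Returns log P(X = 0) at x == 0 (the point mass) and
    log((1 - p_zero) * LogNormal(mu, sigma).pdf(x)) for x > 0.
    """
    if x < 0:
        return -math.inf
    if x == 0:
        return math.log(p_zero)
    # Log-normal log-density: -log x - log sigma - 0.5 log(2 pi) - (log x - mu)^2 / (2 sigma^2)
    z = (math.log(x) - mu) / sigma
    log_pdf = -math.log(x) - math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z
    return math.log(1.0 - p_zero) + log_pdf

print(ziln_log_prob(0))  # log(0.9)
print(ziln_log_prob(1))  # matches the non-nan output reported above
```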

Source code of mixture model log_prob:

  def _log_prob(self, x):
    x = tf.convert_to_tensor(x, name='x')
    distribution_log_probs = [d.log_prob(x) for d in self.components]
    cat_log_probs = self._cat_probs(log_probs=True)
    final_log_probs = [
        cat_lp + d_lp
        for (cat_lp, d_lp) in zip(cat_log_probs, distribution_log_probs)
    ]
    concat_log_probs = tf.stack(final_log_probs, 0)
    log_sum_exp = tf.reduce_logsumexp(concat_log_probs, axis=[0])
    return log_sum_exp

Currently, the mixture model assumes that every x lies in the support of all the component distributions, so that each component's log_prob is non-nan. It cannot handle cases like zero-inflated gamma or log-normal, where the continuous component has support (0, inf).

davmre commented 3 years ago

I think there are a few issues coming up here.

The immediate problem IMHO is that LogNormal(...).prob(0) returns nan rather than 0 (or equivalently that log_prob(0) returns nan rather than -inf). In general we don't want MixtureSameFamily to suppress NaNs, because they can also arise from non-support-related numerical issues where silently doing an incorrect calculation is not desirable. But MixtureSameFamily should do the right thing with -inf log probs.
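The difference between a nan and a -inf component log_prob can be reproduced with plain NumPy. This sketch assumes, as in TFP, that LogNormal is a Normal pushed through an exp bijector, so log_prob(x) = Normal.log_prob(log x) - log x. At x = 0 that evaluates to -inf - (-inf) = nan, and a single nan poisons the log-sum-exp over components, whereas a -inf component simply drops out.

```python
import numpy as np

log_weights = np.log([0.9, 0.1])   # Categorical(probs=[0.9, 0.1])
det_lp_at_0 = 0.0                  # Deterministic(0).log_prob(0) = log(1)

with np.errstate(divide="ignore", invalid="ignore"):
    log_x = np.log(0.0)                                              # -inf
    normal_lp = -0.5 * (log_x - 6.0) ** 2 - 0.5 * np.log(2 * np.pi)  # Normal(6, 1).log_prob(-inf) = -inf
    lognormal_lp = normal_lp - log_x                                 # -inf - (-inf) = nan

    # Mixture log_prob = logsumexp over components of (log weight + component log_prob)
    mixture_with_nan = np.logaddexp(log_weights[0] + det_lp_at_0, log_weights[1] + lognormal_lp)
    mixture_with_neg_inf = np.logaddexp(log_weights[0] + det_lp_at_0, log_weights[1] - np.inf)

print(lognormal_lp, mixture_with_nan, mixture_with_neg_inf)
```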

We've generally been reluctant to guarantee the out-of-support behavior of distributions like LogNormal, because an explicit correction along the lines of log_prob = tf.where(x > 0, naive_logprob, -inf) creates a performance hit in the frequent case where you do know that you're only going to evaluate the distribution within its support. One solution could be to add a flag force_probs_to_zero_outside_support to LogNormal and similar distributions that would control this behavior.
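Below is a NumPy sketch of the explicit correction described above. It assumes the common "safe value" pattern: out-of-support inputs are replaced by an in-support dummy before taking log(x), and the final `where` then selects -inf for them. In TF this matters beyond the forward value, because a nan or inf in the unselected branch of tf.where can still turn gradients into nan. Both helper names are hypothetical, and the extra `where` calls are the per-evaluation cost mentioned above.

```python
import numpy as np

def lognormal_log_prob_with_support(x, mu, sigma):
    """Hypothetical helper: LogNormal log-density that returns -inf for x <= 0."""
    x = np.asarray(x, dtype=float)
    in_support = x > 0
    # Substitute a harmless in-support value so log(x) never sees 0 or negatives.
    safe_x = np.where(in_support, x, 1.0)
    log_x = np.log(safe_x)
    naive = (-log_x - np.log(sigma) - 0.5 * np.log(2 * np.pi)
             - 0.5 * ((log_x - mu) / sigma) ** 2)
    return np.where(in_support, naive, -np.inf)

def ziln_mixture_log_prob(x, p_zero=0.9, mu=6.0, sigma=1.0):
    """Hypothetical helper: logsumexp over both components, like Mixture._log_prob."""
    x = np.asarray(x, dtype=float)
    det_lp = np.where(x == 0, 0.0, -np.inf)   # Deterministic(0).log_prob
    return np.logaddexp(np.log(p_zero) + det_lp,
                        np.log1p(-p_zero) + lognormal_log_prob_with_support(x, mu, sigma))

print(ziln_mixture_log_prob([0.0, 1.0]))  # [log(0.9), log(0.1) + LogNormal(6, 1).log_prob(1)]
```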

More broadly, zero-inflated models are a tricky point in TFP. Fundamentally, there's no way to write a correct log_prob for them because a density function can't represent the point mass at zero (a point mass has infinite density, so the only possible prob/log_prob value at 0 is inf, but there's no way to indicate how much weight the point mass has). Even if you could represent it, the point mass would be invisible to gradient-based inference algorithms because it has no gradient, so zero-inflated models aren't an effective way to induce sparsity in latent variables; for that you'd want something like the horseshoe.
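To see why the density at 0 cannot encode the weight of the point mass, one can approximate the point mass by a narrow Normal(0, eps) scaled by that weight. The sketch below (a hypothetical helper in plain Python) shows that the density at 0 grows without bound as eps shrinks, for a weight of 0.9 and of 0.1 alike, so in the limit the value at 0 is inf regardless of how much mass sits there.

```python
import math

def smoothed_point_mass_density(weight, eps):
    """Density at 0 of weight * Normal(0, eps): a point mass of size `weight` smeared over width eps."""
    return weight / (eps * math.sqrt(2 * math.pi))

for eps in (1e-1, 1e-3, 1e-6):
    # Both values diverge as eps -> 0; the weight only contributes a finite factor.
    print(eps, smoothed_point_mass_density(0.9, eps), smoothed_point_mass_density(0.1, eps))
```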

Can you share more context around what you're trying to do with the zero-inflated model? We might be able to suggest a more idiomatic approach that would avoid the issue you're running into.

davidlkl commented 3 years ago

Hi Davmre,

Thanks for your reply! I understand that semi-continuous distributions are tricky to handle.

I am modeling insurance claims, which have a large proportion of zeros (>90%), while the claim severity follows a right-skewed distribution. So I am exploring Tweedie, ZI (zero-inflated) lognormal, ZI gamma, etc. Besides the mean, I am also interested in the risk (variance), since a risk loading is normally applied to the pure premium.

brianwa84 commented 3 years ago

If stats on a large empirical sample would suffice, you could write tfp.stats.variance(zero_infl.sample(100_000)). We don't have variance implemented for mixtures afaik.
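For this particular mixture, the mean and variance also have a closed form, which can serve as a check on a sampling estimate. Writing Y = B * S with B ~ Bernoulli(1 - p_zero) and S ~ LogNormal(mu, sigma) independent, E[Y] = (1 - p_zero) exp(mu + sigma^2 / 2) and E[Y^2] = (1 - p_zero) exp(2 mu + 2 sigma^2). The sketch below uses NumPy sampling in place of TFP, with the parameters from the reproduction; the helper names are hypothetical.

```python
import numpy as np

def ziln_mean_var(p_zero, mu, sigma):
    """Hypothetical helper: closed-form mean and variance of a zero-inflated log-normal."""
    q = 1.0 - p_zero                          # probability of a nonzero claim
    m1 = q * np.exp(mu + 0.5 * sigma**2)      # E[Y]
    m2 = q * np.exp(2 * mu + 2 * sigma**2)    # E[Y^2]
    return m1, m2 - m1**2

def ziln_sample(p_zero, mu, sigma, n, rng):
    """Hypothetical helper: draw n samples, zero with probability p_zero."""
    is_claim = rng.random(n) >= p_zero
    return np.where(is_claim, rng.lognormal(mean=mu, sigma=sigma, size=n), 0.0)

rng = np.random.default_rng(0)
mean, var = ziln_mean_var(0.9, 6.0, 1.0)
samples = ziln_sample(0.9, 6.0, 1.0, 1_000_000, rng)
print(mean, samples.mean())
print(var, samples.var())
```

The log-normal tail makes the sample variance converge slowly, so the closed form is the more reliable number when it is available.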

davidlkl commented 3 years ago

Hi @brianwa84, I understand that I can estimate the variance by sampling. What I would like to do is estimate the model parameters policy by policy: the probability of no claim, and the expected claim severity and its dispersion conditional on a claim being made, as illustrated below:

My current approach (only capturing aleatoric uncertainty):

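from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

num_features = 10  # placeholder: number of policy features in X
inputs = Input(shape=(num_features,))  # avoids shadowing the built-in input()
dense1 = Dense(32, activation='relu')(inputs)
dense2 = Dense(32, activation='relu')(dense1)
# p, mu, sigma (unconstrained; the loss must map p to (0, 1) and sigma to > 0)
parameters = Dense(3)(dense2)

model = Model(inputs=inputs, outputs=parameters)
model.compile(optimizer='adam', loss=zero_inflated_lognormal_loss)  # loss defined separately
model.fit(X, Y)  # X: (n, num_features) features, Y: (n,) claim amounts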

I would like to capture both aleatoric and epistemic uncertainty, but I'm not sure how to proceed with the current setup.

brianwa84 commented 3 years ago

Maybe you can write your own loss function, something like

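import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def zero_inflated_lognormal_loss(labels, parameters):  # Keras order: (y_true, y_pred)
  labels = tf.reshape(tf.cast(labels, parameters.dtype), [-1])
  p_logit, mu, sigma_raw = tf.unstack(tf.reshape(parameters, [-1, 3]), axis=-1)  # raw Dense(3) outputs
  is_zero = tf.equal(labels, 0.)
  zeroness_loss = -tfd.Bernoulli(logits=p_logit).log_prob(tf.cast(is_zero, tf.int32))  # models P(label == 0)
  safe_labels = tf.where(is_zero, tf.ones_like(labels), labels)  # keeps 0 out of LogNormal.log_prob
  nonzero_loss = tf.where(is_zero, tf.zeros_like(labels), -tfd.LogNormal(mu, tf.math.softplus(sigma_raw)).log_prob(safe_labels))
  return zeroness_loss + nonzero_loss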
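A NumPy reference of the same masked negative log-likelihood can be handy for sanity-checking a TF loss on a toy batch. In this sketch (hypothetical helper `ziln_nll`), p is taken directly as the probability of a zero label and mu, sigma as the log-normal parameters; any squashing of raw network outputs is assumed to happen beforehand. With the parameters from the original reproduction, the loss for a label of 1 is minus the log_prob reported there.

```python
import numpy as np

def ziln_nll(labels, p, mu, sigma):
    """Hypothetical helper: NumPy reference for the masked zero-inflated log-normal loss."""
    labels = np.asarray(labels, dtype=float)
    is_zero = labels == 0
    # Bernoulli term: -log p for zero labels, -log(1 - p) for positive labels
    zeroness = -np.where(is_zero, np.log(p), np.log1p(-p))
    safe = np.where(is_zero, 1.0, labels)   # dummy value; masked out for zero labels
    log_safe = np.log(safe)
    lognormal_lp = (-log_safe - np.log(sigma) - 0.5 * np.log(2 * np.pi)
                    - 0.5 * ((log_safe - mu) / sigma) ** 2)
    return zeroness + np.where(is_zero, 0.0, -lognormal_lp)

print(ziln_nll([0.0, 1.0], p=0.9, mu=6.0, sigma=1.0))  # [-log(0.9), 21.2215...]
```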

EagleDangar commented 1 year ago

Hi @brianwa84, this is similar to https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py, so what is the prediction function for this?