davidlkl opened this issue 3 years ago
I think there are a few issues coming up here.
The immediate problem IMHO is that `LogNormal(...).prob(0)` returns `nan` rather than `0` (or, equivalently, that `log_prob(0)` returns `nan` rather than `-inf`). In general we don't want `MixtureSameFamily` to suppress NaNs, because they can also arise from non-support-related numerical issues, where silently doing an incorrect calculation is not desirable. But `MixtureSameFamily` should do the right thing with `-inf` log probs.
We've generally been reluctant to guarantee the out-of-support behavior of distributions like `LogNormal`, because an explicit correction along the lines of `log_prob = tf.where(x > 0, naive_logprob, -inf)` creates a performance hit in the frequent case where you do know that you're only going to evaluate the distribution within its support. One solution could be to add a flag `force_probs_to_zero_outside_support` to `LogNormal` and similar distributions that would control this behavior.
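The explicit correction described above can be sketched in plain NumPy (a hypothetical standalone `lognormal_log_prob` helper for illustration, not TFP's implementation): the naive log-density is `nan` at zero because it combines `inf` and `-inf` terms, and a `where` mask patches in `-inf` outside the support.

```python
import numpy as np

def lognormal_log_prob(x, mu=0.0, sigma=1.0):
    """Lognormal log-density with an explicit out-of-support
    correction (illustrative sketch, not TFP code)."""
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        log_x = np.log(x)  # -inf at 0, nan for x < 0
        naive = (-log_x - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)
                 - (log_x - mu) ** 2 / (2.0 * sigma ** 2))
    # The tf.where-style fix: force -inf (i.e. prob = 0) off the support,
    # instead of the nan that `naive` produces at x = 0.
    return np.where(x > 0, naive, -np.inf)
```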
More broadly, zero-inflated models are a tricky point in TFP. Fundamentally, there's no way to write a correct `log_prob` for them, because a density function can't represent the point mass at zero (a point mass has infinite density, so the only possible `prob`/`log_prob` value at 0 is `inf`, but there's no way to indicate how much weight the point mass has). Even if you could represent it, the point mass would be invisible to gradient-based inference algorithms because it has no gradient, so zero-inflated models aren't an effective way to induce sparsity in latent variables; for that you'd want something like the horseshoe.
Can you share more context around what you're trying to do with the zero-inflated model? We might be able to suggest a more idiomatic approach that would avoid the issue you're running into.
Hi Davmre,
Thanks for your reply! I understand that semi-continuous distributions are a tricky point to handle.
I am modeling insurance claims, which have a large proportion of zeros (>90%), while claim severity follows a right-skewed distribution. So I am exploring Tweedie, zero-inflated lognormal, zero-inflated gamma, etc. Instead of just getting the mean, I am also interested in the risk (variance), as a risk loading is normally applied to the pure premium.
If stats on a large empirical sample would suffice, you could write `tfp.stats.variance(zero_infl.sample(100_000))`. We don't have `variance` implemented for mixtures, afaik.
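The same sampling idea can be mimicked outside TFP. Here is a minimal NumPy sketch that draws from a zero-inflated lognormal (the parameter values `p_zero = 0.9`, `mu = 0`, `sigma = 1` are assumptions for illustration) and compares the empirical variance against the closed form:

```python
import numpy as np

rng = np.random.default_rng(0)
p_zero, mu, sigma, n = 0.9, 0.0, 1.0, 100_000

# Draw from the mixture: exactly zero with prob p_zero, lognormal otherwise.
is_claim = rng.random(n) >= p_zero
samples = np.where(is_claim, rng.lognormal(mu, sigma, size=n), 0.0)
empirical_var = samples.var()

# Closed form for comparison: with w = 1 - p_zero,
# E[Y] = w * exp(mu + sigma^2 / 2), E[Y^2] = w * exp(2*mu + 2*sigma^2).
w = 1.0 - p_zero
mean = w * np.exp(mu + sigma ** 2 / 2)
exact_var = w * np.exp(2 * mu + 2 * sigma ** 2) - mean ** 2
```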
Hi @brianwa84, I understand that by sampling I can estimate the variance. What I would like to achieve is to estimate the model parameters (the probability of no claim, and the expected claim severity and dispersion conditional on a claim being made) policy by policy, illustrated below:
My current approach (Only capturing Aleatoric Uncertainty):
```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

inputs = Input(shape=(n_features,))
dense1 = Dense(32, activation="relu")(inputs)
dense2 = Dense(32, activation="relu")(dense1)
# p, mu, sigma
parameters = Dense(3)(dense2)
model = Model(inputs=inputs, outputs=parameters)
model.compile(loss=zero_inflated_lognormal_loss, optimizer="adam")
model.fit(X, Y)
```
I would like to capture both Aleatoric & Epistemic Uncertainty, but not sure how to proceed with the current setting.
Maybe you can write your own loss function, something like
```python
def zero_inflated_lognormal_loss(labels, parameters):  # Keras passes (y_true, y_pred); I forget the ordering
    p, mu, sigma = tf.unstack(parameters, axis=-1)
    is_zero = tf.cast(tf.equal(labels, 0.0), tf.float32)
    zeroness_loss = -tfd.Bernoulli(probs=p).log_prob(is_zero)
    safe_labels = tf.where(tf.equal(labels, 0.0), tf.ones_like(labels), labels)
    nonzero_loss = tf.where(tf.equal(labels, 0.0),
                            tf.zeros_like(labels),
                            -tfd.LogNormal(mu, sigma).log_prob(safe_labels))
    return zeroness_loss + nonzero_loss
```
Hi @brianwa84, this is kind of similar to https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py. What is the prediction function for this?
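A common choice of point prediction for a zero-inflated lognormal is its mean. The following NumPy sketch uses a hypothetical `zi_lognormal_mean` helper (the parameter names and the convention that `p_zero` is the probability of a zero are my assumptions; check them against the lifetime_value code):

```python
import numpy as np

def zi_lognormal_mean(p_zero, mu, sigma):
    """Mean of a zero-inflated lognormal:
    E[Y] = P(claim) * E[LogNormal] = (1 - p_zero) * exp(mu + sigma^2 / 2)."""
    return (1.0 - p_zero) * np.exp(mu + sigma ** 2 / 2.0)
```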
To reproduce the result:
I think the log_prob should be: `log(0.9)` if `x == 0`, and `log(0.1 * lognormal.pdf(x))` if `x > 0`.
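That piecewise definition can be sketched directly in NumPy (assuming the 0.9/0.1 weights above; `zi_log_prob` is a hypothetical helper, not a TFP API):

```python
import numpy as np

def zi_log_prob(x, p_zero=0.9, mu=0.0, sigma=1.0):
    """Piecewise log 'probability': log(p_zero) at exactly 0,
    log((1 - p_zero) * lognormal_pdf(x)) for x > 0, -inf for x < 0."""
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        log_x = np.log(x)
        lognormal_lp = (-log_x - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)
                        - (log_x - mu) ** 2 / (2.0 * sigma ** 2))
    return np.where(x == 0, np.log(p_zero),
                    np.where(x > 0, np.log(1.0 - p_zero) + lognormal_lp,
                             -np.inf))
```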
Source code of mixture model log_prob:
Now, the mixture model assumes the same `x` is always in the support of every component distribution, and hence that each component has a non-NaN `log_prob`. It cannot handle cases like zero-inflated gamma / lognormal, where the continuous component has support (0, inf).
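To see why the naive mixture computation fails at zero: the lognormal density formula at `x = 0` produces an indeterminate `0/0` (the exponential term underflows to 0 while the `1/x` factor blows up), giving `nan`, which then poisons the weighted mixture sum. A small NumPy illustration (standard-lognormal density, 0.9/0.1 weights as assumed above, with a placeholder finite value standing in for the zero component):

```python
import numpy as np

with np.errstate(all="ignore"):
    x = 0.0
    # Standard lognormal pdf: exp(-(ln x)^2 / 2) / (x * sqrt(2*pi)).
    # At x = 0 this is 0/0, i.e. nan.
    pdf_at_zero = np.exp(-np.log(x) ** 2 / 2.0) / (x * np.sqrt(2.0 * np.pi))
    # The nan propagates through the mixture sum regardless of the weights:
    mixture_prob = 0.9 * 1.0 + 0.1 * pdf_at_zero
```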