rajesh-lab / X-CAL


Question about the Lognormal model #2

Open · shi-ang opened this issue 1 year ago

shi-ang commented 1 year ago

Hi,

I have a question about why you chose to subtract 0.5 when computing the lognormal distribution from the predicted parameters. Specifically, I'm referring to this line of code:

log_sigma = F.softplus(pre_log_sigma) - 0.5

in https://github.com/rajesh-lab/X-CAL/blob/721aec34b9f3c94fad732231fd3d0e9c4130fadc/util/distributions.py#LL41C45-L41C48

Did you find that this value performs best empirically, or is there a theoretical justification behind it?

Thanks!

marikgoldstein commented 1 year ago

Hi,

Thanks for writing. The short answer is that this keeps the LogNormal's log sigma large enough for the optimization to be numerically stable.

In more detail, this line appears in a function that maps X through a neural network to the parameters of a LogNormal distribution modeling the time to event T:

Z ~ Normal(mu(X), sigma(X)), T = exp(Z).
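
For concreteness, here is a minimal PyTorch sketch of that mapping (the linear layer, the feature dimension, and the variable names besides pre_log_sigma are illustrative, not the exact X-CAL code):

```python
import torch
import torch.nn.functional as F
from torch.distributions import LogNormal

net = torch.nn.Linear(10, 2)  # hypothetical network; X has 10 features here

x = torch.randn(4, 10)
mu, pre_log_sigma = net(x).chunk(2, dim=-1)  # two unconstrained outputs per subject
log_sigma = F.softplus(pre_log_sigma) - 0.5  # the line in question
dist = LogNormal(mu, log_sigma.exp())        # T = exp(Z), Z ~ Normal(mu, sigma)

t = torch.rand(4, 1) + 0.1                   # dummy positive event times
loglik = dist.log_prob(t)                    # stable, since sigma > exp(-0.5)
```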

The variable T will have unstable log likelihoods if sigma becomes too small: the LogNormal log density contains a -log(sigma) term and a -(log t - mu)^2 / (2 sigma^2) term, both of which become numerically extreme as sigma approaches 0.
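
To see the blow-up concretely, here is a small demo (it uses torch.distributions.LogNormal for illustration, which is not necessarily the likelihood code in the repo):

```python
import torch
from torch.distributions import LogNormal

# For fixed t away from exp(mu), the -(log t - mu)^2 / (2 * sigma^2) term
# grows like 1/sigma^2 as sigma -> 0.
t = torch.tensor(2.0)
for sigma in [1.0, 1e-2, 1e-4, 1e-8]:
    lp = LogNormal(torch.tensor(0.0), torch.tensor(sigma)).log_prob(t)
    print(f"sigma={sigma:g}  log_prob={lp.item():.4g}")
# log_prob plunges toward -inf and its gradients explode, destabilizing training
```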

There are many possible ways to address this. In our case, "pre_log_sigma" is the unconstrained output of a neural network; Softplus() maps it into (0, inf), and subtracting 0.5 means the smallest value log sigma can approach is -0.5, so sigma stays bounded away from zero.
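
A quick check of that floor (plain PyTorch, independent of the repo's code):

```python
import torch
import torch.nn.functional as F

# softplus maps any real input into (0, inf), so softplus(x) - 0.5 lies in
# (-0.5, inf), and sigma = exp(log_sigma) stays above exp(-0.5) ≈ 0.61.
for pre in [-100.0, -1.0, 0.0, 5.0]:
    log_sigma = F.softplus(torch.tensor(pre)) - 0.5
    print(pre, log_sigma.item(), log_sigma.exp().item())
# even for pre = -100, log_sigma ≈ -0.5 and sigma ≈ 0.607: the floor in action
```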

The number -0.5 could be replaced with a dataset-specific value capturing how much you expect each subject's conditional time to event to vary, or with the smallest value your optimization can tolerate.
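
If you want to tune that, one hypothetical variant makes the floor an explicit argument (constrained_log_sigma and log_sigma_min are names I'm making up for illustration):

```python
import torch.nn.functional as F

def constrained_log_sigma(pre_log_sigma, log_sigma_min=-0.5):
    # the smallest sigma the model can predict is exp(log_sigma_min)
    return F.softplus(pre_log_sigma) + log_sigma_min
```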

Hope this helps. Thanks!