ragulpr / wtte-rnn

WTTE-RNN: a framework for churn and time-to-event prediction
MIT License

NaN loss when b approaches 0 #39

Open gm-spacagna opened 6 years ago

gm-spacagna commented 6 years ago

I tried to solve the NaN loss problem and found this trick helpful: add the epsilon constant to the argument of the log (K.log):

```python
loglikelihoods = u * \
    K.log(K.exp(hazard1 - hazard0) - 1.0 + epsilon) - hazard1
```

This way, when b ≈ 0 (so that hazard1 ≈ hazard0 and exp(hazard1 - hazard0) - 1 ≈ 0), the argument of the logarithm stays positive and the log is always defined.
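A minimal NumPy sketch of that limiting case, assuming the discrete-Weibull cumulative hazards hazard0 = ((y + epsilon)/a)^b and hazard1 = ((y + 1)/a)^b (my reading of loglik_discrete in the linked wtte.py; check it against your version). The values y = 3 and a = 2 are arbitrary. At b = 0 both hazards are exactly 1, so without epsilon the log argument is 0: an observed event (u = 1) gives -inf, and a censored one (u = 0) gives nan from 0 * -inf. With epsilon both stay finite.

```python
import numpy as np

def loglik_discrete(y, u, a, b, epsilon=0.0):
    """Discrete-Weibull log-likelihood, mirroring the expression in this issue.

    hazard0 and hazard1 are the cumulative hazards at y and y + 1
    (assumed definitions; compare with your copy of wtte.py).
    """
    hazard0 = np.power((y + epsilon) / a, b)
    hazard1 = np.power((y + 1.0) / a, b)
    # With epsilon > 0 the log argument stays positive even if hazard1 == hazard0.
    return u * np.log(np.exp(hazard1 - hazard0) - 1.0 + epsilon) - hazard1

# b = 0 makes both hazards exactly 1, the limiting case discussed above.
y, a, b = 3.0, 2.0, 0.0
with np.errstate(divide="ignore", invalid="ignore"):
    for u in (1.0, 0.0):  # 1 = observed event, 0 = censored
        plain = loglik_discrete(y, u, a, b)
        patched = loglik_discrete(y, u, a, b, epsilon=1e-7)
        print(f"u={u}: without epsilon {plain}, with epsilon {patched:.4f}")
```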

https://github.com/ragulpr/wtte-rnn/blob/c0075a70efebf96ce022cad2b53dd334b9449c9a/python/wtte/wtte.py#L202

FMArduini commented 5 years ago

I solved it the same way you did. One issue I found, however, was that the loss still went to NaN when training on GPUs (I suspect because of float32 limitations on the GPU, but I'm no expert here).

To make it run on the GPU, I replaced epsilon with 1e-6:

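```python
# k is the Keras backend module (assumed: from keras import backend as k)
exp = k.exp(hazard1 - hazard0) + 1e-6  # epsilon added before subtracting 1
log = k.log(exp - 1)
loglk = -1 * k.mean((u_ * log) - hazard1)  # negative mean log-likelihood
```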

and it seemed to work.
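A small float32 check related to the ordering of the two snippets: the one above adds 1e-6 to exp(...) before subtracting 1, while the first one subtracts 1 before adding epsilon. In float32 the spacing of numbers just above 1.0 is 2**-23 (about 1.19e-7), so with the add-first ordering an epsilon below roughly 6e-8 is rounded away and the log argument is 0 again, whereas 1e-6 survives (rounded to a multiple of that spacing). log_arg_float32 is a hypothetical helper for illustration, evaluated at hazard1 - hazard0 = 0; it is not offered as the explanation for the GPU NaNs.

```python
import numpy as np

def log_arg_float32(x, eps, add_eps_first):
    """Argument of the log in float32, for x = hazard1 - hazard0.

    add_eps_first=True:  (exp(x) + eps) - 1, as in the 1e-6 snippet.
    add_eps_first=False: (exp(x) - 1) + eps, as in the original fix.
    """
    x, eps, one = np.float32(x), np.float32(eps), np.float32(1.0)
    if add_eps_first:
        return (np.exp(x) + eps) - one
    return (np.exp(x) - one) + eps

print("float32 spacing above 1.0:", np.finfo(np.float32).eps)
for eps in (1e-8, 1e-7, 1e-6):
    print(eps, log_arg_float32(0.0, eps, True), log_arg_float32(0.0, eps, False))
```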

Edit: While looking into this, I compared the loss output of two functions that performed the same operations, one written in TensorFlow and the other in NumPy. I evaluated them with alphas and betas of dtype tf.float32, tf.float64, np.float32 and np.float64, and got NaNs only with tf.float32.

ragulpr commented 5 years ago

Great, @FedericoNutmeg, and relevant to https://github.com/ragulpr/wtte-rnn/issues/51.

Currently (on develop and master) it looks like this: https://github.com/ragulpr/wtte-rnn/blob/26612657ee0b14fa1a33f8da6ed28018e27cbe98/python/wtte/wtte.py#L169

Here epsilon is K.epsilon(), which I think defaults to whatever is set in your ~/.keras/keras.json. I suggest changing it; there should actually be a warning for this, but the current message is wrong and unhelpful, and I haven't had time to test it yet. Try keras.backend.set_epsilon(1e-6) and it should behave as you suggested.
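One detail to check when trying keras.backend.set_epsilon(1e-6): if the loss in your copy of wtte.py takes epsilon=K.epsilon() as a default argument, Python evaluates that default once, when the def statement runs at import time, so calling set_epsilon afterwards would not change it. Setting the epsilon before importing wtte, or passing epsilon explicitly, avoids this. The sketch uses a hypothetical stand-in for the backend's epsilon()/set_epsilon() pair so it runs without Keras; it only illustrates Python's default-argument behaviour.

```python
# Hypothetical stand-in for a backend's epsilon()/set_epsilon() pair; not Keras.
_epsilon = 1e-7

def epsilon():
    return _epsilon

def set_epsilon(value):
    global _epsilon
    _epsilon = value

def loss_bound_at_import(x, eps=epsilon()):
    # The default is evaluated once, when this def statement executes.
    return eps

def loss_read_at_call(x, eps=None):
    # Reads the current backend value on every call instead.
    if eps is None:
        eps = epsilon()
    return eps

set_epsilon(1e-6)  # analogous to keras.backend.set_epsilon(1e-6)
print(loss_bound_at_import(0.0))  # -> 1e-07 (stale default)
print(loss_read_at_call(0.0))     # -> 1e-06
```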

In the numeric stability tests I might actually be using float64; this should be updated.
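A minimal sketch of a dtype-aware stability check in NumPy, assuming hazard0 = ((y + epsilon)/a)^b and hazard1 = ((y + 1)/a)^b. check_finite, the b grid, y = 0..9 and a = 2 are hypothetical choices, not the repository's test suite. The dtype assertion guards against the check silently running in float64. The grid is kept small on purpose: with b = 5, y = 9 and a = 2, hazard1 - hazard0 is about 1280 and exp overflows in both dtypes, which is a separate problem.

```python
import numpy as np

def loglik_discrete(y, u, a, b, epsilon):
    # Discrete-Weibull log-likelihood with the epsilon fix from this issue
    # (hazard definitions are assumed, see the lead-in).
    hazard0 = np.power((y + epsilon) / a, b)
    hazard1 = np.power((y + 1.0) / a, b)
    return u * np.log(np.exp(hazard1 - hazard0) - 1.0 + epsilon) - hazard1

def check_finite(dtype, epsilon, b_values=(1e-9, 1e-3, 0.5, 1.0, 2.0)):
    """Hypothetical stability check: True if the loss is finite on a small grid."""
    y = np.arange(10, dtype=dtype)  # y = 0..9
    u = np.ones_like(y)             # all events observed
    a = np.full_like(y, 2.0)
    for b_value in b_values:
        b = np.full_like(y, b_value)
        with np.errstate(all="ignore"):
            loss = loglik_discrete(y, u, a, b, epsilon)
        # Make sure the computation really ran in the requested precision.
        assert loss.dtype == dtype, loss.dtype
        if not np.all(np.isfinite(loss)):
            return False
    return True

for dtype in (np.float32, np.float64):
    print(dtype.__name__, check_finite(dtype, epsilon=1e-6))
```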