havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

NLL PMF loss with partial PMF output #154

Open cirrostratus1 opened 1 year ago

cirrostratus1 commented 1 year ago

I am trying to create a model that outputs the PMF only for the first m future time steps but can still handle samples with survival > m. As far as I understand the paper, phi_m+1(x) is not a real output of the network but is simply set to 0, which makes sense, as it can be inferred (via the softmax) from the rest of the network's output because of the sum-to-1 constraint. I assume that's also the reason for pad_col (which might be the answer to #152),

https://github.com/havakv/pycox/blob/0e9d6f9a1eff88a355ead11f0aa68bfb94647bf8/pycox/models/loss.py#L83

which adds the static output of 0 for phi_m+1. However, I do not understand the purpose of the exception

https://github.com/havakv/pycox/blob/0e9d6f9a1eff88a355ead11f0aa68bfb94647bf8/pycox/models/loss.py#L75-L78

as the function apparently can handle survival beyond m.

Could you explain why this constraint is still needed, or whether it could safely be removed?
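
For concreteness, here is a minimal plain-Python sketch of the padding as I understand it from the paper. It is an illustration only, not the pycox implementation: `padded_pmf` is a hypothetical helper, and the input values are made up.

```python
import math


def padded_pmf(phi):
    """Softmax over the m network outputs plus one fixed logit of 0.

    Hypothetical helper for illustration; not part of pycox.
    """
    logits = list(phi) + [0.0]      # append the static phi_m+1 = 0
    gamma = max(logits)             # subtract the max for numerical stability
    exps = [math.exp(v - gamma) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


phi = [0.5, -1.0, 2.0]              # m = 3 raw outputs for one sample (made-up values)
pmf = padded_pmf(phi)
print(len(pmf))                     # m + 1 entries, i.e. indices 0..m
print(sum(pmf))                     # should be 1 up to floating-point rounding
print(pmf[-1])                      # mass left for survival beyond the first m steps
```

In this sketch, a duration index equal to m addresses a valid entry only after the padding; the unpadded output has just the indices 0..m-1.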