havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License
780 stars 180 forks

fix: fixing the part3 of the nll loss parametrized with PMF #143

Status: Closed (sergiogvz closed this 1 year ago)

sergiogvz commented 1 year ago

First, thanks for your work and the repository. While doing some research with the DeepHitSingle model, we (@EileenHsieh and I) encountered what might be a bug in the loss function nll_pmf. We first noticed it when predicting survival curves: all of them had a sudden drop in the last bin (see Example 1). This is not the extra bin introduced by pad_col, because that bin is removed in the predict function.

Example 1: [image: predicted survival curves showing a sudden drop in the last bin]
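To make the pad_col remark concrete, here is a minimal plain-Python sketch of how survival curves can be derived from PMF logits. It assumes, for illustration only (this is not the pycox implementation), that prediction appends a zero logit per sample, applies a softmax, drops the padded bin, and takes survival as one minus the cumulative PMF. The helpers pad_zero_column, softmax and predict_surv are hypothetical.

```python
import math


def pad_zero_column(phi):
    """Hypothetical stand-in for padding: append a zero logit to every row."""
    return [row + [0.0] for row in phi]


def softmax(row):
    # Subtract the max logit before exponentiating for numerical stability.
    g = max(row)
    exps = [math.exp(x - g) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict_surv(phi):
    """Survival curves from PMF logits: pad, softmax, drop padded bin, 1 - cumsum."""
    curves = []
    for row in pad_zero_column(phi):
        pmf = softmax(row)[:-1]  # remove the bin introduced by padding
        surv, cum = [], 0.0
        for p in pmf:
            cum += p
            surv.append(1.0 - cum)
        curves.append(surv)
    return curves


phi = [[0.5, -1.0, 2.0, 0.3]]  # logits for m = 4 time bins (illustrative values)
surv = predict_surv(phi)[0]
print([round(s, 3) for s in surv])
```

In this sketch the last value of each curve equals the mass the softmax assigns to the padded bin, and the drop at the final bin equals the mass placed in the last real bin, not in the removed padding.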

After comparing the code with the loss function described in the paper (https://doi.org/10.1007/s10985-021-09532-6), we saw that part3 of the loss is missing a +1 on idx_durations. That is, a censored sample i with duration index k(t_i) (idx_durations) survived at least until that time bin, so the loss should use the probability of surviving past that bin: the cumulative sum from k(t_i)+1 to m+1. This change agrees with expression 11 of the paper, and it would fix the sudden drop mentioned above, as shown in Example 2.

Example 2: [image: predicted survival curves after the change, without the sudden drop in the last bin]
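Whether a +1 is needed depends on how part3 evaluates the tail sum. Below is a minimal sketch, assuming for illustration only that part3 takes the log of the total exponentiated mass minus the cumulative sum gathered at idx_durations, over logits that already include the padded bin. part3_mass and tail_mass are hypothetical helpers, indices are 0-based, and the sketch returns the normalized quantity inside the logarithm. It compares that quantity with a direct sum over later bins, with and without an extra +1.

```python
import math


def _exps(phi_row):
    # Exponentiate logits after subtracting the max (stable softmax numerators).
    g = max(phi_row)
    return [math.exp(x - g) for x in phi_row]


def part3_mass(phi_row, idx, shift=0):
    """Hypothetical helper: (total - cumsum[idx + shift]) / total.

    phi_row holds logits for all m + 1 bins (padding already applied).
    """
    exps = _exps(phi_row)
    cumsum, running = [], 0.0
    for e in exps:
        running += e
        cumsum.append(running)
    total = cumsum[-1]
    return (total - cumsum[idx + shift]) / total


def tail_mass(phi_row, start):
    """Direct sum of softmax probabilities for bins start, ..., m (0-based)."""
    exps = _exps(phi_row)
    return sum(exps[start:]) / sum(exps)


phi_row = [0.5, -1.0, 2.0, 0.3, 0.0]  # m = 4 bins plus one padded bin
k = 1  # idx_durations value of a censored sample

# Without a shift, total - cumsum[k] already sums the bins after k.
print(part3_mass(phi_row, k), tail_mass(phi_row, k + 1))
# With an extra +1, bin k + 1 is excluded as well.
print(part3_mass(phi_row, k, shift=1), tail_mass(phi_row, k + 2))
```

Under this assumption, subtracting the cumulative sum at k already excludes bin k, so an additional +1 would also drop bin k+1 from the censored-sample term.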

sergiogvz commented 1 year ago

Actually this fix is not correct, my bad!