havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Experience in train/val/test partitions with survival data in real-world datasets #125

Open paulamartingonzalez opened 2 years ago

paulamartingonzalez commented 2 years ago

Hi @havakv! I was in touch a few months ago about a GNN implementation that uses the pycox losses. I am now moving to a real-world dataset and I am struggling with the partitions.

I've done random partitions and consistently see that the validation loss during training is smaller than the training loss, even before training starts. I am using the LogisticHazard loss and discretising the durations beforehand. I've double-checked that there's no data leakage, and when I look at the survival curves of the two partitions, they do look rather different. Could that be the cause? I've tried stratifying on the "time" variable, and although the issue improves, I still see some odd behaviour in the concordance index (it is better on the validation set than on the training set, even though the model is trained and optimised on the training set).

So I have two questions:

1. Do you have any experience with partitioning these kinds of datasets? Is what I am seeing normal, or is there something I am missing?
2. Is it possible to use other metrics (like the time-dependent AUC in scikit-surv) with models trained with pycox? I'm comparing performance against regular Cox regression, so I'd rather use the same implementation if possible.

Thanks!

havakv commented 2 years ago

Hi, these are good questions, but I'm not sure how much help I can be. If you see very different survival predictions for your training and validation set, that means that the data sets are quite different right? So that would explain the difference in loss?

When you say "stratified", you mean balancing the training and validation sets, right? I.e., not as in "stratified Cox regression"? If you're talking about balancing the datasets, I don't really have any experience doing this for survival data, but I would probably make sure to stratify on both the event times and the censoring times, so that both distributions are balanced between the datasets. If you have different proportions of censored samples, your loss is likely to differ too.
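For illustration, a minimal sketch of one way to do this (not something I've tested), assuming 1D numpy arrays durations and events and scikit-learn's train_test_split: bin the durations into quantiles and stratify on the (duration bin, event indicator) pair, which balances both the event-time and censoring-time distributions at once.

import numpy as np
from sklearn.model_selection import train_test_split

# durations, events: 1D numpy arrays (assumed names, not from the thread)
n_bins = 10
# inner quantile edges used to bin the durations
edges = np.quantile(durations, np.linspace(0, 1, n_bins + 1)[1:-1])
duration_bin = np.digitize(durations, edges)
# combine the duration bin with the event indicator so that the
# censored/uncensored proportions are preserved within each bin
strata = duration_bin * 2 + events.astype(int)

idx_train, idx_val = train_test_split(
    np.arange(len(durations)), test_size=0.2, stratify=strata, random_state=0)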

You can monitor other metrics for your training and validation sets. For this you can make a custom callback by following the steps in #94. Let me know if you're having problems making that work.

paulamartingonzalez commented 2 years ago

Hi @havakv, thanks so much for your reply! When I say stratification I mean stratification in the train_test_split function used to do the partition, yes! I actually followed your advice and found a workaround to stratify on both the time to event and the censoring time, and the survival functions look much more comparable now. The losses during training also make sense now :)

I am trying to compare the neural nets to basic Cox regression performance using a few clinical features, and it would be useful to use the scikit-surv metrics for that purpose, but I am struggling to make them work. They expect the "estimated risk of experiencing an event of test data" with shape (n_samples,) to be specified. But after training the LogisticHazard model, the predictions that I get using predict_surv are the survival functions, with shape (n_samples, n_discrete_time_points).

Is there any way to predict the "estimated risk of experiencing an event of test data" using the LogisticHazard model, so that I can use the scikit-surv metrics? Thanks so much in advance!!

havakv commented 2 years ago

Hi again @paulamartingonzalez. I'm glad you were able to make it work.

A lot of survival scores have been developed with Cox proportional hazards (CoxPH) regression in mind, as it's a very common model. Because of the proportional hazards assumption, we only need a single number to summarize the difference in survival between two individuals in a Cox regression (the hazards of two individuals are always proportional, and the survival functions will never cross). So the metrics developed for CoxPH naturally just use this single number, as it summarizes the risk of an individual. If you want to create a similar risk score from a non-proportional-hazards model, you need to be a bit creative. I guess you could use something like the median survival time (the time at which the survival probability drops below 0.5) or any other quantile, or you could have a look at the suggestions in #33 and #41. You should give both of those a read, as they cover some ground on this topic.
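As a rough illustration of the quantile idea (a sketch only, with quantile_survival_time being a hypothetical helper, and surv assumed to come from model.predict_surv_df(x_test)):

import numpy as np

# surv: DataFrame whose index is the time grid and whose columns are individuals
def quantile_survival_time(surv, q=0.5):
    times = surv.index.values
    below = surv.values <= q            # (n_times, n_samples) boolean
    idx = below.argmax(axis=0)          # first time each curve drops below q
    est = times[idx].astype(float)
    est[~below.any(axis=0)] = times[-1] # never crosses q: fall back to last time
    return est

# higher risk should mean a larger score, so negate the survival time
risk_score = -quantile_survival_time(surv, q=0.5)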

If you're only interested in a good risk score, you might want to consider a survival model that provides one directly (like the Cox regression), as I don't expect the more flexible models like LogisticHazard, DeepHit, CoxTime, etc., to improve on metrics that require such scores. If you want to show an improvement in the survival predictions of your model compared to standard Cox, that is likely simpler with something like the Brier score or the concordance index by Antolini, which both use the full survival predictions and can therefore expose some of the problems of the proportional hazards models.
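For reference, both of these are available through pycox's EvalSurv; a minimal sketch, assuming numpy arrays durations_test and events_test:

import numpy as np
from pycox.evaluation import EvalSurv

surv = model.predict_surv_df(x_test)
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

# Antolini's time-dependent concordance (uses the full survival curves)
ctd = ev.concordance_td('antolini')

# integrated Brier score over a grid of time points
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ibs = ev.integrated_brier_score(time_grid)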

paulamartingonzalez commented 2 years ago

Thanks so much for the explanation @havakv! I think your suggestion is what I will end up doing: using the Brier score and the concordance index.

I am using scikit-surv for the Cox regression baseline. I see that the Brier score formulation there is the same as in pycox (taken from this paper). Nevertheless, the concordance indexes calculated there are different (this one and this one).

Would it be possible to compare the LogisticHazard model with the Cox regression using one of the concordance indexes provided by scikit-surv, to ensure a fair comparison? Is there any way I can input the LogisticHazard results into any of those functions? Thanks in advance!

Update: I wouldn't mind using lifelines instead, but I can't make that work either.

havakv commented 2 years ago

So, for concordance_index_censored in scikit-surv, the first three arguments are event_indicator, event_time, and estimate, which are all 1D arrays. The first two should be trivial, as they are the regular event indicator and event time targets. The estimate, on the other hand, is a bit trickier (as mentioned earlier) and needs to be created from the survival estimates of the Logistic-Hazard. In #33 I wrote an example of how this can be done using an index of the survival predictions. A possibly better approach than choosing an arbitrary index could be to use a quantile of each individual survival prediction, i.e., the time at which the predicted survival crosses a specific value. If you set this value to 0.5, you would be using the median survival time as your estimate. The downside is that some survival predictions might not cross 0.5 within the time range of interest. I'll let you try to code this up yourself.
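As a starting point, a hedged sketch that reuses the hypothetical quantile_survival_time helper sketched earlier in the thread, assuming numpy arrays durations_test and events_test:

from sksurv.metrics import concordance_index_censored

# negate the median survival time so that higher values mean higher risk
estimate = -quantile_survival_time(surv, q=0.5)
cindex = concordance_index_censored(
    events_test.astype(bool), durations_test, estimate)[0]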

For concordance_index_ipcw, the estimate needs to be created in the same way, but the two other arguments are structured arrays. There are probably examples in the scikit-surv docs that show how these structured arrays can be created.
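A minimal sketch of that, assuming the same estimate as above and using sksurv's Surv.from_arrays to build the structured arrays:

from sksurv.util import Surv
from sksurv.metrics import concordance_index_ipcw

# structured arrays built from plain event/duration arrays
y_train = Surv.from_arrays(events_train.astype(bool), durations_train)
y_test = Surv.from_arrays(events_test.astype(bool), durations_test)

cindex_ipcw = concordance_index_ipcw(y_train, y_test, estimate)[0]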

glgmartin commented 2 years ago

Hi @paulamartingonzalez, I've had the exact same problem with the time-dependent AUC (from sksurv) and pycox. I could make it work like so:

import sksurv.util
from sksurv.metrics import cumulative_dynamic_auc

# start by defining a function to transform 1D arrays to the
# structured arrays needed by sksurv
def transform_to_struct_array(times, events):
    return sksurv.util.Surv.from_arrays(events, times)

# then call the time-dependent AUC metric from sksurv;
# times is the grid of evaluation time points (see below)
cumulative_dynamic_auc(
    transform_to_struct_array(durations_train, events_train),
    transform_to_struct_array(durations_test, events_test),
    model.predict(x_test).squeeze(),
    times)

Here model is any pycox model, and I squeezed the predicted outputs to avoid sksurv's checks for 2D targets. times is a required argument of the sksurv AUC function. Side note: I'm not sure why, but I always get really similar results from this and the concordance_td function in EvalSurv ... we're talking a difference in the 4th decimal.

mahootiha-maryam commented 2 years ago

Thanks a lot @glgmartin for the code for calculating the AUC with sksurv. I couldn't structure the arrays, and this code helped me. I can just add how to build times, the last input to cumulative_dynamic_auc, since you didn't mention it. I hope it helps people who want to use it:

import numpy as np

# evaluation grid: evenly spaced time points over the test follow-up,
# with the step chosen so the grid has roughly as many points as the
# model's discrete time grid
start_t = min(durations_test)
end_t = max(durations_test)
step = round((end_t - start_t) / model.predict(x_test).squeeze().shape[1])
times = np.arange(start_t, end_t, step)