havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

AssertionError during c-index computation #46

Open lucapalazzi opened 4 years ago

lucapalazzi commented 4 years ago

Hi, I ran into the following assertion error when computing the c-index for the discrete MTLR method.

assert durations.shape[0] == surv.shape[1] == surv_idx.shape[0] == events.shape[0]

I suppose the error arises because the maximum test duration is 1628***, while the function receives as input a number between 0 and 490. This range (0, 490) results from applying the following:

from pycox.models import MTLR
num_durations = 50
scheme = 'quantiles'
labtrans = MTLR.label_transform(num_durations, scheme)

and

surv = model.interpolate(10).predict_surv_df(x_test)

As the parameter of the function interpolate increases, the number of grid points also increases, and vice versa. The number of grid points is roughly the product of num_durations and the parameter of interpolate. In the example I followed step by step, it was also pointed out that in the plot "the time scale is correct because we have set model.duration_index to be the grid points".

Thanks in advance, Luca

*** EDIT: I hadn't read the error carefully, and now I understand what I did wrong (I saved only part of the df derived from surv and then computed the c-index on the entire test set). Unfortunately, I still haven't figured out how to fix the time scale problem in discrete models.

havakv commented 4 years ago

Hi! Thank you for the feedback! It's a little hard to debug this from your explanation alone, so if you're still having problems, would it be possible to send me a short piece of code with some data that reproduces this bug? If you can't send me your data set, could you create a simple dataset that results in this bug? In that case I would be happy to take a look!

havakv commented 4 years ago

So, for completeness: the statement assert durations.shape[0] == surv.shape[1] == surv_idx.shape[0] == events.shape[0] only checks that the number of samples (individuals, or whatever one would call them) in the test set is the same in the durations, events and surv_idx vectors, and that we have the same number of survival predictions, surv.shape[1]. You correctly identified this problem, @lucapalazzi, when your surv object did not match the samples in the test set.

I'm still not quite sure what your problem with the time scale in the discrete models is. The statement "the time scale is correct because we have set model.duration_index to be the grid points" just means that the index in the surv object is set to represent points in time rather than the discretization index. This is, however, needed to compute the concordance, so any bug (or unintuitive API) needs to be fixed.

lucapalazzi commented 4 years ago

Hi. I found the error this morning; it has nothing to do with the API and was my own oversight. In particular, saving the result of predict_surv_df(x) to a csv file and then loading it in another script caused the problem. This step dropped the column with the grid points, leaving only the positional indexes.
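For readers hitting the same thing, here is a minimal sketch of one way a csv round trip can lose the time grid. The small DataFrame is a made-up stand-in for the output of predict_surv_df (rows are time points, columns are individuals), and the grid values are invented for illustration; the thread does not show how the file was actually written or read.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical stand-in for a predict_surv_df result: the index holds the
# time grid, each column is one test-set individual. Values are made up.
time_grid = np.array([0.0, 407.0, 814.0, 1221.0, 1628.0])
surv = pd.DataFrame(
    np.linspace(1.0, 0.2, 5)[:, None] * np.ones((5, 3)),
    index=time_grid,
)

# Write to an in-memory csv (the index is written by default).
buffer = io.StringIO()
surv.to_csv(buffer)

# Read back without index_col: the time grid becomes an ordinary data
# column and the frame gets a fresh positional index 0, 1, 2, ...
buffer.seek(0)
lost = pd.read_csv(buffer)

# Read back with index_col=0: the first column is restored as the index.
buffer.seek(0)
kept = pd.read_csv(buffer, index_col=0)

print(list(lost.index))   # positional indexes only
print(list(kept.index))   # the original time grid
```

Note that after a csv round trip the column labels come back as strings, so it may also be worth checking that the columns still line up with the test set before evaluating.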

The reported error did not occur with the continuous models because their indexes matched the grid points, but it did occur with the discrete ones. That's why, after setting 50 grid points, I found 491 rows in the csv file with indexes in the range 0-490 rather than 0-1628.
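The 491 rows can be reproduced with a little arithmetic. This is only a sketch under an assumption: that interpolate(10) splits each of the intervals between neighbouring grid points into 10 sub-intervals. That assumption is consistent with the numbers reported in this thread, but the helper name below is hypothetical and not part of pycox.

```python
def interpolated_grid_size(num_durations: int, sub: int) -> int:
    """Number of points when each of the (num_durations - 1) intervals
    between neighbouring grid points is split into `sub` sub-intervals.

    Hypothetical helper for illustration; not a pycox function.
    """
    return (num_durations - 1) * sub + 1


n_rows = interpolated_grid_size(50, 10)
print(n_rows)      # 491 rows
print(n_rows - 1)  # 490, the largest positional index
```

So the range 0-490 is just the row position on the interpolated grid, not a time in the units of the durations.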

I apologize for hastily opening issues that turn out to be trivial problems, which could have been solved by paying more attention. Thank you for your time.

havakv commented 4 years ago

I'm happy you figured it out, and it is not a problem that you opened an issue. Clearly the assert durations.shape[0] == surv.shape[1] == surv_idx.shape[0] == events.shape[0] needs to print an error message so that the user can better understand why it is failing. Opening issues helps with improving the API.
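As a rough idea of what a more informative check could look like, here is a minimal sketch. The function name check_concordance_inputs and the wording of the message are hypothetical; this is not the change made in pycox, only the same shape condition with the offending sizes spelled out.

```python
import numpy as np


def check_concordance_inputs(durations, events, surv, surv_idx):
    """Hypothetical helper: same condition as the assert, with a message."""
    sizes = {
        "durations.shape[0]": durations.shape[0],
        "surv.shape[1]": surv.shape[1],
        "surv_idx.shape[0]": surv_idx.shape[0],
        "events.shape[0]": events.shape[0],
    }
    # All four sizes must equal the number of individuals in the test set.
    if len(set(sizes.values())) != 1:
        details = ", ".join(f"{name}={size}" for name, size in sizes.items())
        raise ValueError(
            "Expected one entry per test individual in every input, got "
            + details
        )


# Made-up data: three individuals, but survival predictions for only two.
durations = np.array([5.0, 8.0, 12.0])
events = np.array([1, 0, 1])
surv_idx = np.array([0, 1, 1])
surv_partial = np.ones((4, 2))

message = ""
try:
    check_concordance_inputs(durations, events, surv_partial, surv_idx)
except ValueError as err:
    message = str(err)
    print(message)
```

With a message like this, a surv object that covers only part of the test set shows up immediately as a mismatch in surv.shape[1].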