Open aslitoj opened 3 years ago
Hi!
The Cox models (CoxPH, CoxCC and CoxTime) don't really work like the other models implemented here, as some additional details are needed for the non-parametric baseline hazards (which are estimated with model.compute_baseline_hazards).
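For anyone wondering why this extra step exists: the network only learns the relative risk g(x), so survival curves also need a baseline hazard estimated from training data. Below is a minimal pure-Python sketch of a Breslow-type estimator. It is an illustration under that assumption and not the library's actual code; the function name `breslow_baseline_hazards` and the toy data are hypothetical.

```python
import math
from collections import defaultdict


def breslow_baseline_hazards(durations, events, log_risks):
    """Breslow-type estimate (illustrative, not the library's implementation).

    h0(t) = (number of events at t) / (sum of exp(g(x_j)) over subjects with T_j >= t)
    Only times with at least one event are returned.
    """
    events_at = defaultdict(int)
    risk_at = defaultdict(float)
    for t, e, g in zip(durations, events, log_risks):
        events_at[t] += int(e)
        risk_at[t] += math.exp(g)

    hazards = {}
    at_risk = 0.0
    # Walk from the largest time down so at_risk accumulates every subject with T_j >= t.
    for t in sorted(risk_at, reverse=True):
        at_risk += risk_at[t]
        if events_at[t] > 0:
            hazards[t] = events_at[t] / at_risk
    return dict(sorted(hazards.items()))


# Toy data: four subjects, one censored at t=2, all with g(x) = 0.
durations = [1.0, 2.0, 2.0, 3.0]
events = [1, 1, 0, 1]
log_risks = [0.0, 0.0, 0.0, 0.0]
baseline = breslow_baseline_hazards(durations, events, log_risks)
```

Because the estimate depends on both the network outputs and the observed (duration, event) pairs, it has to be computed from data after fitting, which is why the Cox models need the extra call.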
I think just unpacking the dataset should do the trick for you here. So for a dataset train_data of size (torch.Size([16, 1, 128, 128]), (torch.Size([16]), torch.Size([16]))), you should just need to run model.compute_baseline_hazards(*train_data).
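To make the unpacking concrete, here is a small sketch of what the `*` does with a nested (input, (durations, events)) tuple. The function below is a hypothetical stand-in that only mirrors an assumed (input, target) argument order; it is not the real method, and plain lists stand in for tensors.

```python
def compute_baseline_hazards(input=None, target=None):
    """Hypothetical stand-in: takes the input and the target as two separate arguments."""
    durations, events = target
    return len(input), len(durations), len(events)


# Same nesting as the dataset above: (images, (durations, events)).
images = [[0.0] * 4 for _ in range(16)]
durations = [float(i) for i in range(16)]
events = [i % 2 for i in range(16)]
train_data = (images, (durations, events))

# The star splits the outer tuple into two positional arguments,
# so this is the same call as compute_baseline_hazards(images, (durations, events)).
result = compute_baseline_hazards(*train_data)
```

Passing train_data without the star would hand the whole tuple over as the input and leave the target empty.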
If you want to build the baseline hazards based on the full training set, you can concatenate all the batches in your dataloader dl_train with the following code:
train_data = tt.tuplefy([data for data in dl_train]).cat()
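In effect, that one-liner walks over every batch and joins the matching pieces along the first axis while keeping the nested structure. A rough pure-Python sketch of the same idea, with plain lists standing in for tensors and a hypothetical helper `cat_batches` that is not part of torchtuples:

```python
def cat_batches(batches):
    """Join a list of (x, (durations, events)) batches into one nested tuple."""
    xs, durations, events = [], [], []
    for x, (d, e) in batches:
        xs.extend(x)
        durations.extend(d)
        events.extend(e)
    return xs, (durations, events)


# Two toy batches of sizes 2 and 1.
batch_a = (["img0", "img1"], ([5.0, 3.0], [1, 0]))
batch_b = (["img2"], ([7.0], [1]))

full_x, (full_durations, full_events) = cat_batches([batch_a, batch_b])
```

Note that this keeps the whole training set in memory at once, which may matter for large image datasets.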
I guess there should be a way to estimate the baseline hazards by using the dataloaders directly, as this way is very unintuitive. Thanks for raising the issue!
Hello @havakv, Thank you for your reply. Your suggestion above solved the issue. I am glad that I am able to contribute. Lastly, thanks a lot for this great work!
Happy to help, and thank you for the kind words! I'll just let this issue stay open for a while to remind me that this is something that really should be improved in the future.
Hi, I met the same problem in prediction with the CoxPH model and a GNN net. My input data is a graph, but the target is a tensor. When I ran model.compute_baseline_hazards(), I got the following error:
ValueError: All objects in 'data' doest have the same type.
Could you please help me to fix this problem?
Hello,
I'd like to build a model which takes images and predicts overall survival time as a continuous value. For that reason, I followed the model shown in the Jupyter notebook 04_mnist_dataloaders_cnn.ipynb, using CoxPH instead of LogisticHazards. However, I got 2 different errors. I am using databatchloader, by the way.
When I tried CoxPH and fit the model with:
callbacks = [tt.cb.EarlyStopping()]
epochs = 100
verbose = True
log = model.fit_dataloader(dl_train, epochs, callbacks, verbose, val_dataloader=dl_val)
Running this code (net is the same as in the sample notebook mentioned above):
model = CoxPH(net, tt.optim.Adam(0.01))
surv = model.predict_surv_df(dl_test_x)
gave me this error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)