jaredleekatzman / DeepSurv

DeepSurv is a deep learning approach to survival analysis.
MIT License

Error in predict_risk(test_data) #58

Open XPeriment2 opened 4 years ago

XPeriment2 commented 4 years ago

After following the notebook and training with `model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)`, I call `model.get_concordance_index(**test_data)` and get a concordance index (CI). However, `predictions = model.predict_risk(test_data)` raises an exception.
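The two calls behave differently because of Python argument unpacking. `**test_data` expands the dict into keyword arguments `x=..., t=..., e=...`, while `predict_risk(test_data)` binds the whole dict to the single `x` parameter. The sketch below shows this with hypothetical stub functions that only mirror the parameter names of DeepSurv's methods. They are not the library's code, and the random dataset is made up for illustration.

```python
import numpy as np

# Hypothetical stand-ins that only mirror the parameter names of DeepSurv's
# get_concordance_index(x, t, e, ...) and predict_risk(x); NOT the library code.
def get_concordance_index_stub(x, t, e):
    # Report which array arrived under which parameter name
    return {"x": x.shape, "t": t.shape, "e": e.shape}

def predict_risk_stub(x):
    # Assumption: like the real method, x must be the covariate array itself
    if not isinstance(x, np.ndarray):
        raise TypeError(f"x must be a numpy array, got {type(x).__name__}")
    return x.sum(axis=1)  # placeholder "risk": one value per row

# A made-up dataset dict with the x / t / e keys used in the thread
test_data = {
    "x": np.random.rand(5, 3).astype("float32"),  # covariates, shape (n, d)
    "t": np.random.rand(5).astype("float32"),     # observed times
    "e": np.ones(5, dtype="int32"),               # event indicators
}

# ** unpacks the dict: equivalent to get_concordance_index_stub(x=..., t=..., e=...)
shapes = get_concordance_index_stub(**test_data)

# Passing the dict positionally binds the entire dict to x, which fails
try:
    predict_risk_stub(test_data)
except TypeError as err:
    error_message = str(err)
```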


hdplsa commented 4 years ago

When you call model.predict_risk, you should pass only the x matrix (the covariates), not the whole test_data dictionary. You should be able to fix the problem by calling this instead:

```python
predictions = model.predict_risk(test_data["x"])
```
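Once `predict_risk` receives only the covariate matrix, the returned risks can be sanity-checked against the CI reported by `get_concordance_index`. Below is a minimal NumPy sketch of Harrell's concordance index. It assumes a higher predicted risk means an earlier expected event, and it skips pairs with tied times. It is an independent sketch, not DeepSurv's own implementation, so small differences from `get_concordance_index` (for example in tie handling) are possible.

```python
import numpy as np

def harrell_c_index(t, e, risk):
    """Harrell's C for risk scores where higher risk = earlier expected event.

    t: observed times; e: event indicators (1 = event, 0 = censored);
    risk: one predicted risk per subject (shape (n,) or (n, 1)).
    Pairs with tied times are skipped in this simplified version.
    """
    t = np.asarray(t, dtype=float).ravel()
    e = np.asarray(e).ravel().astype(bool)
    risk = np.asarray(risk, dtype=float).ravel()  # flatten an (n, 1) output

    concordant = 0.0
    comparable = 0
    for i in range(len(t)):
        if not e[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(len(t)):
            if t[i] < t[j]:  # i had its event before j's observed time
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # model ranked i as riskier: concordant
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs (e.g. no observed events)")
    return concordant / comparable
```

For example, `harrell_c_index(test_data["t"], test_data["e"], predictions)` should land close to the value from `model.get_concordance_index(**test_data)`. A large mismatch would suggest the risks were computed on the wrong input.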