jaredleekatzman / DeepSurv

DeepSurv is a deep learning approach to survival analysis.
MIT License

Save and reload the model? #42

Open EliotZhu opened 6 years ago

EliotZhu commented 6 years ago

Hi there, I wonder whether the load-model function is working properly. The weights and updates appear to be saved and reloaded correctly, but when I use the reloaded model to predict on the same data it was trained on, the c-index is significantly lower. It looks like the model is not being reconstructed properly. It would be helpful to have an example of the correct way to save and reload a trained model.

Thanks.
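One framework-agnostic way to check a save/reload round trip is to compare the restored model's predictions, not just the stored arrays. The sketch below uses NumPy's `.npz` format as a stand-in for DeepSurv's HDF5 weights file. The helpers `save_params`, `load_params` and the toy `forward` network are hypothetical illustrations, not DeepSurv's API. It stores each parameter array under its list index and reloads the arrays in numeric order. Sorting the keys as strings would put "10" before "2" and silently shuffle layers.

```python
import os
import tempfile

import numpy as np


def save_params(path, params):
    """Hypothetical helper: store each array under its list index ("0", "1", ...)."""
    np.savez(path, **{str(i): p for i, p in enumerate(params)})


def load_params(path):
    """Hypothetical helper: reload arrays in numeric index order."""
    with np.load(path) as data:
        keys = sorted(data.files, key=int)  # int sort keeps "2" before "10"
        return [data[k] for k in keys]


def forward(params, x):
    """Toy risk network: ReLU hidden layers, then a linear output unit."""
    layers = list(zip(params[0::2], params[1::2]))  # (weights, bias) pairs
    h = x
    for w, b in layers[:-1]:
        h = np.maximum(h @ w + b, 0.0)
    w, b = layers[-1]
    return (h @ w + b).ravel()


rng = np.random.default_rng(0)
sizes = [4, 8, 8, 8, 8, 8, 1]  # 6 layers -> 12 arrays, saved as keys "0".."11"
params = []
for n_in, n_out in zip(sizes[:-1], sizes[1:]):
    params.append(rng.normal(size=(n_in, n_out)))
    params.append(rng.normal(size=n_out))
x = rng.normal(size=(5, 4))

path = os.path.join(tempfile.mkdtemp(), "weights.npz")
save_params(path, params)
restored = load_params(path)

# Round-trip check: identical arrays and identical predictions on the same input
assert all(np.array_equal(a, b) for a, b in zip(params, restored))
assert np.array_equal(forward(params, x), forward(restored, x))

# A plain string sort would have put "10" and "11" before "2"
print(sorted(str(i) for i in range(12))[:5])  # ['0', '1', '10', '11', '2']
```

With DeepSurv itself, the analogous check would compare the trained `model` and the reloaded model on the same inputs. If the arrays match but the outputs differ, something other than the stored weights is affecting the predictions.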

allendorf commented 5 years ago

Hi, bumping this because I have the same issue. To reproduce, run the "DeepSurv Example" notebook and then save the model with:

```python
model.save_model('bestparams.json', weights_file='bestweights.h5')
```

Then run the following cell repeatedly. Each run gives a different c-index; I've gotten values anywhere between 0.25 and 0.65.


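```python
import theano
import deepsurv
from lifelines.utils import concordance_index

# train_data: the dict with 'x', 'e', 't' arrays from the "DeepSurv Example" notebook

# Rebuild the network from the saved hyperparameters and load the trained weights
model2 = deepsurv.deep_surv.load_model_from_json('bestparams.json', weights_fp='bestweights.h5')

# Reapply the standardization statistics computed from the training data
if model2.standardize:
    model2.offset = train_data['x'].mean(axis=0)
    model2.scale = train_data['x'].std(axis=0)

# Standardize the inputs (if enabled) and unpack x, e, t
x_train2, e_train2, t_train2 = model2.prepare_data(train_data)

# Negate the hazard: lifelines expects higher scores to mean longer survival
compute_hazards = theano.function(inputs=[model2.X], outputs=-model2.partial_hazard)
partial_hazards2 = compute_hazards(x_train2)

ci_train = concordance_index(t_train2, partial_hazards2, e_train2)
print(ci_train)
```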
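For reference, the cell above computes Harrell's concordance index. It considers comparable pairs, meaning pairs where the subject with the shorter time had an observed event. It then reports the fraction of those pairs in which that subject received the worse predicted survival. The plain-Python `harrell_c` below is a simplified illustration, not lifelines' implementation: it skips pairs with tied times and counts tied scores as 0.5. Like `lifelines.utils.concordance_index`, it expects scores where larger means longer survival, which is why the cell passes `-partial_hazard`. The data are made up. The example shows that 0.5 is what random ranking gives, and that negating the scores turns C into 1 − C when there are no tied scores. A value near 0.25 therefore means the ranking is largely inverted.

```python
from itertools import combinations


def harrell_c(times, scores, events):
    """Harrell's C for scores where larger means longer predicted survival.

    Simplified sketch: pairs with tied times are skipped, tied scores count 0.5.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(times)), 2):
        if times[i] == times[j]:
            continue
        # Order the pair so that `a` has the shorter time
        a, b = (i, j) if times[i] < times[j] else (j, i)
        if not events[a]:
            continue  # a was censored first, so the true ordering is unknown
        comparable += 1
        if scores[a] < scores[b]:
            concordant += 1.0
        elif scores[a] == scores[b]:
            concordant += 0.5
    return concordant / comparable


times = [2.0, 5.0, 3.0, 8.0, 6.0]
events = [1, 1, 0, 1, 1]
hazard = [0.9, 0.4, 0.5, 0.1, 0.6]  # higher = riskier, like partial_hazard

c = harrell_c(times, [-h for h in hazard], events)  # negated, as in the cell above
c_flipped = harrell_c(times, hazard, events)        # sign convention reversed
print(c, c_flipped)  # 6 of 7 comparable pairs concordant vs. 1 of 7
```

The function is deterministic: the same times, events and scores always give the same value. A c-index that changes each time the cell is re-run therefore means the predicted hazards themselves change between reloads, for example because some parameters are re-initialized rather than restored. This sketch cannot tell which cause applies.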