autonlab / auton-survival

Auton Survival - an open source package for Regression, Counterfactual Estimation, Evaluation and Phenotyping with Censored Time-to-Events
http://autonlab.github.io/auton-survival
MIT License

Applying DSM on competing risks data #95

Closed Bokyeong1001 closed 1 year ago

Bokyeong1001 commented 1 year ago

Hello!

I am trying to apply Deep Survival Machines (DSM) to my dataset, which has 23 competing events, and the results are not as good as I expected.

In the DSM paper, with SEER data, DSM and DeepHit c-index results are pretty comparable, but with my data DeepHit c-index results are about 10% better than DSM.

I wrote my code based on your DSM example notebook. I only changed the evaluation part, since the example notebook handles a single event.

Is it possible that DSM is unsuitable for data with many events? Or is it just that my code has a problem?

Here is my code.

```python
import numpy as np

from process_severance import make_data
from auton_survival.preprocessing import Preprocessor
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.model_selection import ParameterGrid
from auton_survival.models.dsm import DeepSurvivalMachines
from sksurv.metrics import concordance_index_ipcw, brier_score

data_path = './data/dummy_data.csv'
input, Y, features = make_data(data_path)

cat_feats = ["SEX1"]
num_feats = features
num_feats.remove('SEX1')

features = Preprocessor().fit_transform(input, cat_feats=cat_feats, num_feats=num_feats)

# Evaluation times: quartiles of the observed event times
horizons = [0.25, 0.5, 0.75]
times = np.nanquantile(Y["event_time"], horizons).tolist()

x, t, e = input.to_numpy(), Y["event_time"].to_numpy(), Y["label"].to_numpy()

kf = KFold(n_splits=5, shuffle=True, random_state=1234)
fold = 0

for train_index, test_index in kf.split(x):
    x_train     = x[train_index]
    t_train     = t[train_index]
    e_train     = e[train_index]
    x_test      = x[test_index]
    t_test      = t[test_index]
    e_test      = e[test_index]

    # Hold out 20% of the training fold for validation
    (x_train, x_val, t_train, t_val, e_train, e_val) = train_test_split(x_train, t_train, e_train, test_size=0.20, random_state=1234)

    param_grid = {'k' : [3, 4, 6],
            'distribution' : ['LogNormal', 'Weibull'],
            'learning_rate' : [ 1e-4, 1e-3],
            'layers' : [ [], [100], [100, 100] ]
            }
    params = ParameterGrid(param_grid)

    models = []
    for param in params:
        model = DeepSurvivalMachines(k = param['k'],
                                    distribution = param['distribution'],
                                    layers = param['layers'])
        # The fit method is called to train the model
        model.fit(x_train, t_train, e_train, iters = 100, learning_rate = param['learning_rate'])
        models.append([[model.compute_nll(x_val, t_val, e_val), model, param]])
        #break
    best_model = min(models)

    out_risk = model.predict_risk(x_test, times)
    out_survival = model.predict_survival(x_test, times)

    # Evaluate each of the 23 events as a binary event-vs-rest problem
    for ev in range(23):
        cis = []
        brs = []
        e_train_new = (e_train == ev+1)
        e_test_new = (e_test == ev+1)

        et_train = np.array([(e_train_new[i], t_train[i]) for i in range(len(e_train_new))],
                        dtype = [('e', bool), ('t', float)])
        et_test = np.array([(e_test_new[i], t_test[i]) for i in range(len(e_test_new))],
                        dtype = [('e', bool), ('t', float)])

        for i, _ in enumerate(times):
            try:
                cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0])
            except:
                cis.append(np.nan)
        try:
            brs.append(brier_score(et_train, et_test, out_survival, times)[1])
        except:
            brs.append([np.nan,np.nan,np.nan])

        for i, horizon in enumerate(horizons):
            print(f"For {horizon} quantile,")
            print("TD Concordance Index:", cis[i])
            print("Brier Score:", brs[0][i])
    fold = fold + 1
```

chiragnagpal commented 1 year ago

It's hard to pinpoint the source of the problem without more information on the dataset.

In our experience, DSM typically works at least as well as, if not better than, DeepHit on most datasets we experimented with.

Have you tried tuning DSM's temperature parameter, 'temp'? Perhaps try values in [1, 50, 100]. (The default is 1000.)
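One way to act on the 'temp' suggestion is to add it to the hyperparameter grid from the original post. The following is a minimal sketch, not the maintainers' code: the helper name `build_param_grid` is hypothetical, and it assumes that 'temp' is passed to the `DeepSurvivalMachines` constructor alongside `k`, `distribution` and `layers`.

```python
from itertools import product


def build_param_grid():
    """Return one dict per hyperparameter combination, including 'temp'."""
    grid = {
        "k": [3, 4, 6],
        "distribution": ["LogNormal", "Weibull"],
        "learning_rate": [1e-4, 1e-3],
        "layers": [[], [100], [100, 100]],
        "temp": [1, 50, 100],  # values suggested in the comment above
    }
    keys = list(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


params = build_param_grid()

# Each entry `p` would then be used roughly as follows (assumption: `temp`
# is a constructor argument of DeepSurvivalMachines):
#   model = DeepSurvivalMachines(k=p["k"], distribution=p["distribution"],
#                                layers=p["layers"], temp=p["temp"])
#   model.fit(x_train, t_train, e_train, iters=100,
#             learning_rate=p["learning_rate"])
```

Adding three temperature values triples the number of models trained per fold, so it may be worth narrowing the other ranges first.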

If that doesn't help, feel free to email me and we can discuss offline.

chiragnagpal commented 1 year ago

@Bokyeong1001 I think there might be a bug in your code.

You aren't actually using the model with the lowest NLL; `out_risk` and `out_survival` are computed from `model`, which is simply the last model trained in the loop. Make sure the predictions come from the model with the lowest validation negative log-likelihood.
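The selection step described above can be sketched as follows. This is an illustration under stated assumptions: `select_best` is a hypothetical helper, and the strings and NLL values in `candidates` are placeholders standing in for fitted `DeepSurvivalMachines` objects and their `compute_nll` results. Selecting with `key=` compares only the NLL values, so ties never fall through to comparing model objects, and non-finite NLLs from diverged runs are skipped.

```python
import math


def select_best(candidates):
    """Return the (nll, model, param) tuple with the lowest finite NLL."""
    finite = [c for c in candidates if math.isfinite(c[0])]
    if not finite:
        raise ValueError("no candidate has a finite validation NLL")
    return min(finite, key=lambda c: c[0])


# Placeholder entries: in the original loop each one would be
#   (model.compute_nll(x_val, t_val, e_val), model, param)
candidates = [
    (1.92, "model_a", {"k": 3}),
    (1.47, "model_b", {"k": 4}),
    (float("nan"), "model_c", {"k": 6}),  # e.g. a run that diverged
]

best_nll, best_model, best_param = select_best(candidates)

# Predictions should then come from `best_model`, not the loop variable:
#   out_risk = best_model.predict_risk(x_test, times)
#   out_survival = best_model.predict_survival(x_test, times)
```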

Bokyeong1001 commented 1 year ago

Thank you so much!

I will fix the bug and try again, this time also tuning the 'temp' parameter.

If it still doesn't work, I will email you with information about the dataset.