autonlab / auton-survival

Auton Survival - an open source package for Regression, Counterfactual Estimation, Evaluation and Phenotyping with Censored Time-to-Events
http://autonlab.github.io/auton-survival
MIT License
323 stars 75 forks source link

Error in the example notebook: 'RandomSurvivalForest' object has no attribute 'event_times_' #131

Open raheems opened 1 year ago

raheems commented 1 year ago

There is an error in the RSF model in the example notebook `Survival Regression with Auton-Survival.ipynb`.

Here is code that reproduces the error:

import pandas as pd
import sys
# Add the parent of the current working directory to the import path so a
# repository checkout of auton_survival can be imported. Because this is an
# append (searched last), an installed auton_survival takes precedence; the
# traceback below indeed resolves to the site-packages copy.
# NOTE(review): presumably assumes the notebook is run from a subfolder of the
# repository (e.g. examples/) — confirm.
sys.path.append('../')

from auton_survival.datasets import load_dataset

Load data and features

# Fetch the SUPPORT data: `outcomes` supplies the 'time' / 'event' survival
# targets and `features` the covariates.
outcomes, features = load_dataset(dataset='SUPPORT')

# Categorical covariates (one-hot encoded by the preprocessor further down).
cat_feats = [
    'sex', 'dzgroup', 'dzclass',
    'income', 'race', 'ca',
]

# Continuous covariates (mean-imputed by the preprocessor further down).
num_feats = [
    'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
    'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
    'glucose', 'bun', 'urine', 'adlp', 'adls',
]

Preprocess

import numpy as np
from sklearn.model_selection import train_test_split
from auton_survival.preprocessing import Preprocessor

# Split the SUPPORT data into training, validation, and test data:
# 20% is held out for test, then 25% of the remaining 80% for validation,
# i.e. a 60% / 20% / 20% split.
x_tr, x_te, y_tr, y_te = train_test_split(features, outcomes, test_size=0.2, random_state=1)
x_tr, x_val, y_tr, y_val = train_test_split(x_tr, y_tr, test_size=0.25, random_state=1)

print(f'Number of training data points: {len(x_tr)}')
print(f'Number of validation data points: {len(x_val)}')
print(f'Number of test data points: {len(x_te)}')

# Fit the imputer and scaler to the training data only, then transform the
# training, validation and test data. Fitting on x_tr (not the full
# `features` frame) keeps validation/test statistics from leaking into the
# imputation and scaling used for model selection and evaluation.
# NOTE(review): assumes every category present in x_val / x_te also occurs in
# x_tr; confirm how the fitted one-hot encoding handles unseen categories.
preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat='mean')
transformer = preprocessor.fit(x_tr, cat_feats=cat_feats, num_feats=num_feats,
                               one_hot=True, fill_value=-1)
x_tr = transformer.transform(x_tr)
x_val = transformer.transform(x_val)
x_te = transformer.transform(x_te)

Fit RSF

from auton_survival.estimators import SurvivalModel
from auton_survival.metrics import survival_regression_metric
from sklearn.model_selection import ParameterGrid

# Hyperparameter search space for the random survival forest.
param_grid = {
    'n_estimators': [100, 300],
    'max_depth': [3, 5],
    'max_features': ['sqrt', 'log2'],
}
params = ParameterGrid(param_grid)

# Evaluation horizons: the 10th, 20th, ..., 100th percentiles of the event
# times observed (event == 1) in the training data.
event_mask = y_tr['event'] == 1
times = np.quantile(y_tr['time'][event_mask], np.linspace(0.1, 1, 10)).tolist()

# Grid search: fit one RSF per configuration and score it on the validation
# set with the Integrated Brier Score (selected by minimum below).
models = []
for param in params:
    # Every grid point carries exactly n_estimators, max_depth and
    # max_features, so it is forwarded directly as keyword arguments.
    model = SurvivalModel('rsf', random_seed=8, **param)
    model.fit(x_tr, y_tr)

    predictions_val = model.predict_survival(x_val, times)
    metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_tr)
    models.append([metric_val, model])

# Keep the first configuration whose validation IBS is the smallest.
metric_vals = [score for score, _ in models]
first_min_idx = metric_vals.index(min(metric_vals))
model = models[first_min_idx][1]

AttributeError

AttributeError                            Traceback (most recent call last)
File <command-1063050024682857>:25
     22 model.fit(x_tr, y_tr)
     24 # Obtain survival probabilities for validation set and compute the Integrated Brier Score 
---> 25 predictions_val = model.predict_survival(x_val, times)
     26 metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_tr)
     27 models.append([metric_val, model])

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-677c5cec-55bc-4bf2-85ea-946a24e7ad0b/lib/python3.9/site-packages/auton_survival/estimators.py:699, in SurvivalModel.predict_survival(self, features, times)
    697   return _predict_cph(self._model, features, times)
    698 elif self.model == 'rsf':
--> 699   return _predict_rsf(self._model, features, times)
    700 elif self.model == 'dsm':
    701   return _predict_dsm(self._model, features, times)

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-677c5cec-55bc-4bf2-85ea-946a24e7ad0b/lib/python3.9/site-packages/auton_survival/estimators.py:477, in _predict_rsf(model, features, times)
    472   times = [float(times)]
    474 survival_predictions = model.predict_survival_function(features.values,
    475                                                        return_array=True)
    476 survival_predictions = pd.DataFrame(survival_predictions,
--> 477                                     columns=model.event_times_).T
    479 return __interpolate_missing_times(survival_predictions, times)

AttributeError: 'RandomSurvivalForest' object has no attribute 'event_times_'
raheems commented 1 year ago

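I think I've found the solution. The installed version of scikit-survival does not provide an `event_times_` attribute on `RandomSurvivalForest`; newer releases expose the fitted time grid as `unique_times_` instead.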

On line 477 of `auton_survival/estimators.py`, change

columns=model.event_times_).T

to

columns=model.unique_times_).T

Is that correct?