Sreenath2019 opened this issue 3 years ago.
@Sreenath2019, DiCE should work with the sksurv package. As long as you set up DiCE correctly and the sksurv models support predict()/predict_proba() functions, DiCE should be able to generate counterfactuals. Please give it a try and let us know.
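Whether a model can be explained as a classifier comes down to which prediction methods the model object actually exposes. The following is a minimal duck-typing sketch of that check. The helper `suggest_dice_model_type` and the two stub classes are hypothetical stand-ins, not part of DiCE or sksurv. The returned strings are chosen to match the values of the `model_type` argument of `dice_ml.Model`.

```python
class ClassifierStub:
    """Stand-in for an estimator that exposes both predict() and predict_proba()."""

    def predict(self, X):
        return [0 for _ in X]

    def predict_proba(self, X):
        return [[1.0, 0.0] for _ in X]


class SurvivalStub:
    """Stand-in for a survival estimator that only exposes predict() (risk scores)."""

    def predict(self, X):
        return [0.5 for _ in X]


def suggest_dice_model_type(model):
    """Hypothetical helper: map the methods a model exposes to a DiCE model_type value."""
    if callable(getattr(model, "predict_proba", None)):
        return "classifier"  # class probabilities are available
    if callable(getattr(model, "predict", None)):
        return "regressor"  # only a continuous output, e.g. a risk score
    raise TypeError(f"{type(model).__name__} has neither predict() nor predict_proba()")


print(suggest_dice_model_type(ClassifierStub()))  # classifier
print(suggest_dice_model_type(SurvivalStub()))    # regressor
```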
FYI: @amit-sharma
@gaugup Thanks for the reply.
I did try DiCE on an sksurv model I built and got the error message: AttributeError: 'GradientBoostingSurvivalAnalysis' object has no attribute 'predict_proba'.
It has a predict() function but no predict_proba() function. Reference here.
@Sreenath2019, could you paste the code you executed and the stack trace? For classification scenarios, predict_proba() is a supported function. It's strange that this particular sksurv model doesn't support predict_proba().
Regards,
@gaugup here is the code I used to create the model, as well as to (try to) create the counterfactuals.
```python
import numpy as np
import pandas as pd
import dice_ml
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# Defining a boosted Cox survival model
boostcox_tuned = GradientBoostingSurvivalAnalysis(n_estimators=64, min_samples_split=10, min_samples_leaf=2, max_features='auto', max_depth=11, loss='coxph')
boostcox_tuned.fit(train_data_x, train_data_y)

# Setup data: the sksurv target is a structured array with two fields (Tenure & the Cancelled yes/no flag),
# so turn it into a DataFrame and join it to the features column-wise (axis=1)
target_df = pd.DataFrame(train_data_y, index=train_data_x.index)
dice_data_source = pd.concat([train_data_x, target_df], axis=1)
# Removing Tenure to leave a single outcome column (Cancelled) for DiCE
dice_data_source = dice_data_source.drop(columns=['Tenure'])
# Take continuous features from the feature columns only, so the outcome is never included
continuous_features = train_data_x.select_dtypes(include=np.number).columns.tolist()
dice_data = dice_ml.Data(dataframe=dice_data_source, continuous_features=continuous_features, outcome_name='Cancelled')

# Using sklearn backend with the fitted survival model defined above
dice_model = dice_ml.Model(model=boostcox_tuned, backend="sklearn")
# Using method=random for generating CFs
exp = dice_ml.Dice(dice_data, dice_model, method="random")
# Query instance: one full row of features as a one-row DataFrame (no outcome column)
query_instance = train_data_x.iloc[[1]]
# Generate counterfactual examples
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class="opposite")
# Visualize counterfactual explanation
dice_exp.visualize_as_dataframe()
```
Running this code throws the following error:
AttributeError: 'GradientBoostingSurvivalAnalysis' object has no attribute 'predict_proba'
Please note that this is a survival model (not a classification model with a definite yes/no outcome) and hence returns risk scores/hazard ratios, not probabilities.
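Because the model returns risk scores rather than probabilities, one possible workaround is to fix a time horizon t and treat "event by t" as the class, with P(event by t) = 1 − S(t). This workaround is not provided by DiCE or sksurv. The following is a minimal sketch of such an adapter. `FixedHorizonClassifier` and `ExponentialSurvivalStub` are hypothetical names. The adapter assumes the wrapped model exposes `predict_survival_function(X)` returning one callable S_i(t) per row, which is how sksurv's survival-function API behaves for `GradientBoostingSurvivalAnalysis` with `loss='coxph'`.

```python
import math


class FixedHorizonClassifier:
    """Hypothetical adapter: turns a survival model into a binary classifier at `horizon`.

    Class 1 = the event (e.g. cancellation) happens by `horizon`;
    class 0 = still event-free at `horizon`. The wrapped model is assumed to expose
    predict_survival_function(X) returning one callable S_i(t) per row.
    """

    def __init__(self, survival_model, horizon):
        self.survival_model = survival_model
        self.horizon = horizon

    def predict_proba(self, X):
        # One [P(class 0), P(class 1)] pair per row, with P(event by horizon) = 1 - S(horizon)
        rows = []
        for surv_fn in self.survival_model.predict_survival_function(X):
            s = float(surv_fn(self.horizon))
            rows.append([s, 1.0 - s])
        return rows

    def predict(self, X):
        # Pick the more probable class for each row
        return [int(p1 > p0) for p0, p1 in self.predict_proba(X)]


class ExponentialSurvivalStub:
    """Toy survival model: S_i(t) = exp(-rate_i * t), with rate_i taken from the first feature."""

    def predict_survival_function(self, X):
        return [lambda t, r=row[0]: math.exp(-r * t) for row in X]


clf = FixedHorizonClassifier(ExponentialSurvivalStub(), horizon=12.0)
X = [[0.01], [0.2]]
print(clf.predict_proba(X))
print(clf.predict(X))  # [0, 1]
```

With sksurv, `survival_model` would be the fitted `boostcox_tuned`. The horizon would be expressed in the same units as `Tenure` and should lie inside the range of follow-up times seen in training. Reframing the question as "will the customer cancel within t?" makes it a classification problem, but it ignores the rest of the survival curve. DiCE's sklearn backend may also expect NumPy arrays rather than Python lists, so treat this as a starting point only.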
Thank you
Thanks @Sreenath2019 for the sample code. If the outputs of the GradientBoostingSurvivalAnalysis model are ratios, shouldn't this be explained with DiCE as a regression task instead? You could set the 'desired_range' parameter of the generate_counterfactuals() function to the expected range.
Thank you @gaugup. I am not sure I understood you fully. Could you please give a short code example, assuming my desired range of ratios is between 0 and 1? That might give me a better picture. Thanks
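The following is a self-contained sketch of what a desired range of 0 to 1 asks for, matching the request above. It does not call DiCE. `hazard_ratio` is a hypothetical toy model standing in for a survival model's output. `random_counterfactuals` is a simplified random search written for illustration, not DiCE's own implementation. Each sample copies the query and resamples a random subset of features within their ranges. Candidates whose predicted ratio lands in the desired range are kept, and the closest ones are returned as counterfactuals.

```python
import math
import random


def hazard_ratio(x):
    """Hypothetical toy model: Cox-style hazard ratio exp(linear predictor) for two features."""
    return math.exp(0.03 * x[0] + 0.4 * x[1] - 2.0)


def random_counterfactuals(predict, query, feature_ranges, desired_range,
                           total_cfs=4, n_samples=5000, seed=0):
    """Simplified random-search counterfactuals for a regression-style output.

    Each sample copies `query` and resamples a random subset of its features
    uniformly within `feature_ranges`. Candidates whose prediction falls inside
    `desired_range` are kept, and the `total_cfs` closest to `query`
    (L1 distance, each feature scaled by its range) are returned as
    (distance, features, prediction) tuples.
    """
    rng = random.Random(seed)
    low, high = desired_range
    spans = [(hi - lo) or 1.0 for lo, hi in feature_ranges]  # avoid dividing by zero
    hits = []
    for _ in range(n_samples):
        cand = list(query)
        k = rng.randint(1, len(query))  # how many features to change
        for i in rng.sample(range(len(query)), k):
            lo, hi = feature_ranges[i]
            cand[i] = rng.uniform(lo, hi)
        y = predict(cand)
        if low <= y <= high:
            dist = sum(abs(c - q) / s for c, q, s in zip(cand, query, spans))
            hits.append((dist, cand, y))
    hits.sort(key=lambda h: h[0])
    return hits[:total_cfs]


query = [80.0, 3.0]  # two made-up customer features
ranges = [(20.0, 120.0), (0.0, 6.0)]
print(f"query hazard ratio: {hazard_ratio(query):.2f}")  # 4.95, outside [0, 1]
for dist, cf, hr in random_counterfactuals(hazard_ratio, query, ranges, desired_range=(0.0, 1.0)):
    print([round(v, 1) for v in cf], round(hr, 3), round(dist, 3))
```

With DiCE itself, the corresponding setup has two parts. First, declare the model as a regressor with `dice_ml.Model(model=boostcox_tuned, backend="sklearn", model_type="regressor")`. Second, pass a range instead of `desired_class`, e.g. `exp.generate_counterfactuals(query_instance, total_CFs=4, desired_range=[0, 1])`. There are two caveats. For `loss='coxph'`, sksurv documents `predict()` as a risk score on the log hazard ratio scale, so a 0 to 1 range there means something different from a hazard ratio between 0 and 1. In addition, the `outcome_name` column given to `dice_ml.Data` would presumably need to be continuous rather than the binary `Cancelled` flag.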
First of all, thank you so much for creating this wonderful package. I am wondering whether counterfactuals using DiCE can be extended to survival models, especially those created with the sksurv package (from the scikit-learn family). I couldn't find any existing package that supports counterfactuals for survival models, and it would be extremely helpful to have one. The paper "Counterfactual explanation of machine learning survival models" by Maxim S. Kovalev and Lev V. Utkin provides guidelines on this. Thank you