havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Applying SHAP to pycox models #126

Open VanishingRasengan opened 2 years ago

VanishingRasengan commented 2 years ago

Dear @havakv Sir,

First of all, I want to thank you for maintaining this amazing Python package for survival analysis!

However, I am trying to apply SHAP (https://github.com/slundberg/shap), a tool for explainable AI, to pycox, but somehow it does not seem to be compatible.

I followed the same steps as in the CoxPH example notebook (https://nbviewer.org/github/havakv/pycox/blob/master/examples/cox-ph.ipynb?ref=https://githubhelp.com) and added this at the end:

import shap

explainer = shap.Explainer(model, x_train)
shaps = explainer(x_test)

The error message is:

The passed model is not callable and cannot be analyzed directly with the given masker! Model: <pycox.models.cox.CoxPH object at 0x7f88195c3d90>

Is pycox not compatible with SHAP, or did I apply the "Explainer" method incorrectly?
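
The error message suggests that shap.Explainer received an object it could neither call nor recognize. A pycox CoxPH instance has a predict method but is not itself callable, so one workaround is to hand SHAP a plain function instead. The sketch below uses a hypothetical stand-in class (FakeCoxPH) in place of a trained pycox model so that it runs without pycox or SHAP installed; make_risk_fn is likewise a hypothetical helper name, not part of either library.

```python
import numpy as np


class FakeCoxPH:
    """Hypothetical stand-in for a trained pycox CoxPH model.

    Like the real model, it has a predict() method but no __call__,
    so the object itself is not callable.
    """

    def __init__(self, weights):
        self.weights = np.asarray(weights, dtype=np.float32)

    def predict(self, x):
        # Linear log-risk, returned with shape (n, 1) like a one-output net.
        return (np.asarray(x, dtype=np.float32) @ self.weights)[:, None]


def make_risk_fn(model):
    """Wrap model.predict in a plain function SHAP can call on NumPy batches."""

    def risk_fn(x):
        x = np.asarray(x, dtype=np.float32)
        # Flatten (n, 1) -> (n,) so the explainer sees a single output.
        return np.asarray(model.predict(x)).reshape(-1)

    return risk_fn


model = FakeCoxPH([1.0, 2.0, 0.0])
risk_fn = make_risk_fn(model)
print(callable(model), callable(risk_fn))  # False True
print(risk_fn([[1.0, 1.0, 5.0], [0.0, 0.5, 1.0]]))  # [3. 1.]
```

With a trained pycox model, the wrapper would be used as shap.Explainer(make_risk_fn(model), x_train). This assumes that predict returns a NumPy array when given a NumPy array; the reshape is only there to turn a single-column output into a vector.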

Kind regards, Muralee

havakv commented 2 years ago

Hi, we haven't made any effort to make pycox compatible with SHAP. I can see that SHAP has some preliminary support for PyTorch, but I'm not really familiar with it. If you feel like it, you could try following some of the PyTorch examples in SHAP and replicating them with pycox.

I guess full support would be nice, if someone were to look into it.

teituroli commented 2 years ago

Hi, is there any chance you have found a solution?

All the best, T

thecml commented 2 years ago

I haven't tested this with pycox myself, but I have had good results with scikit-survival and SHAP using the predict trick, i.e. passing model.predict to the explainer instead of the model itself:

import shap

explainer = shap.Explainer(model.predict, X_train, feature_names=feature_names)
shap_values = explainer(X_test)
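
The same trick works for outputs other than the risk score. For a pycox model, predict_surv_df returns a DataFrame of survival curves (rows are time points, columns are individuals; for CoxPH it requires compute_baseline_hazards() to have been run first), so one could explain the estimated survival probability at a fixed horizon instead. The sketch below is an assumption-laden illustration: surv_at is a hypothetical helper, the horizon is arbitrary, and fake_predict_surv_df is a made-up stand-in for model.predict_surv_df so the example runs with only NumPy and pandas.

```python
import numpy as np
import pandas as pd


def surv_at(predict_surv_df, horizon):
    """Return a SHAP-friendly function giving S(horizon | x) for each row of x.

    predict_surv_df: callable returning a DataFrame whose index holds time
    points and whose columns hold one survival curve per input row (the
    layout pycox uses). Uses the last time point <= horizon (step function).
    """

    def fn(x):
        surv = predict_surv_df(np.asarray(x, dtype=np.float32))
        times = np.asarray(surv.index, dtype=float)
        # Position of the last time point not exceeding the horizon.
        idx = np.searchsorted(times, horizon, side="right") - 1
        if idx < 0:
            # Before the first time point, a survival curve is assumed to be 1.
            return np.ones(surv.shape[1])
        return surv.to_numpy()[idx]

    return fn


def fake_predict_surv_df(x):
    """Hypothetical stand-in: exponential curves with hazard exp(x0 + 2*x1)."""
    times = np.array([0.5, 1.0, 2.0])
    hazard = np.exp(x[:, 0] + 2 * x[:, 1])
    curves = np.exp(-np.outer(times, hazard))  # shape (n_times, n_rows)
    return pd.DataFrame(curves, index=times)


f = surv_at(fake_predict_surv_df, horizon=1.5)
print(f([[0.0, 0.0], [1.0, 0.0]]))
```

With a real model this would read shap.Explainer(surv_at(model.predict_surv_df, horizon), x_train), where horizon is whatever time point is of interest.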

hellorp1990 commented 2 years ago

@VanishingRasengan @teituroli Hi, I am having a similar issue with SHAP. Do you have a solution for using SHAP with pycox?

teituroli commented 2 years ago

Hi @hellorp1990, I did not spend much time trying to find a solution. Instead, I wrote my own loss function based on the Cox proportional hazards model. You can find the (very messy) code here. Feel free to ask any questions. I then used SHAP as follows:

import shap
explainer = shap.Explainer(model.predict, x_train)
shap_values = explainer(x_test)
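
For reference, a loss "based on the Cox proportional hazards model" is usually the negative log partial likelihood, averaged over observed events. The linked code is not reproduced here, so the NumPy sketch below is not that code; it is a generic version that assumes no tied event times (ties would need Breslow or Efron handling). pycox ships its own PyTorch version of this loss (CoxPHLoss in pycox.models.loss), which CoxPH uses during training.

```python
import numpy as np


def cox_neg_log_partial_likelihood(log_h, durations, events):
    """Negative log partial likelihood of a Cox model, averaged over events.

    log_h: predicted log-risk scores h(x_i), shape (n,)
    durations: observed times t_i, shape (n,)
    events: 1 if the event was observed, 0 if censored, shape (n,)
    Assumes no tied event times.
    """
    log_h = np.asarray(log_h, dtype=float)
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=float)

    # Sort by descending duration so a running sum covers each risk set
    # {j : t_j >= t_i}.
    order = np.argsort(-durations, kind="stable")
    log_h, events = log_h[order], events[order]

    # log(sum_{j in risk set} exp(h_j)), computed in a numerically stable way.
    log_risk_set = np.logaddexp.accumulate(log_h)

    n_events = events.sum()
    if n_events == 0:
        return 0.0
    return -np.sum(events * (log_h - log_risk_set)) / n_events


# Three equal risks with all events observed: the loss equals log(6) / 3.
print(cox_neg_log_partial_likelihood([0.0, 0.0, 0.0], [1, 2, 3], [1, 1, 1]))
```

Censored individuals never contribute a term of their own, but they still appear in the risk sets of everyone who had an event before them.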

LuisPerezLombardia commented 1 year ago

Hi, this seems to work for me on the linear synthetic dataset from:

Katzman, J. L., Shaham, U., Cloninger, A., et al. (2018). DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Med Res Methodol, 18, 24.

Here, the true log-risk function is:

h(x) = x_0 + 2*x_1
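
Because h(x) is linear here, its SHAP values are known in closed form: with an independent (interventional) background, feature i receives w_i * (x_i - mean of x_i over the background), and features with zero weight receive 0. That gives a useful ground truth for checking the plot. The brute-force NumPy sketch below (exact_shapley is a hypothetical helper) enumerates all feature subsets to compute exact Shapley values and compares them with that formula; the random background is synthetic, not the DeepSurv data, and only three features are used to keep the 2^M enumeration cheap.

```python
import itertools
import math

import numpy as np


def exact_shapley(f, x, background):
    """Exact Shapley values of f at point x with an independent background.

    The value of a coalition S is the mean of f over background rows with
    the features in S replaced by x[S]. Cost is O(2^M), so keep M small.
    """
    x = np.asarray(x, dtype=float)
    background = np.asarray(background, dtype=float)
    m = x.shape[0]

    def value(subset):
        cols = list(subset)
        data = background.copy()
        data[:, cols] = x[cols]
        return f(data).mean()

    phi = np.zeros(m)
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            weight = math.factorial(size) * math.factorial(m - size - 1) / math.factorial(m)
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


def h(X):
    # Linear log-risk from the thread: h(x) = x_0 + 2*x_1 (x_2 is pure noise).
    return X[:, 0] + 2 * X[:, 1]


rng = np.random.default_rng(0)
background = rng.uniform(-1, 1, size=(200, 3))
x = np.array([0.5, -0.3, 0.9])
phi = exact_shapley(h, x, background)
expected = np.array([1.0, 2.0, 0.0]) * (x - background.mean(axis=0))
print(np.round(phi, 4), np.round(expected, 4))
```

If a trained network has recovered h(x), its SHAP values for x0 and x1 should follow this pattern, and the remaining features should be close to zero.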

I still need to fine-tune the hyperparameters; the model is overfitting the training dataset.

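import torch
import torchtuples as tt
import numpy as np
import shap
from pycox.models import CoxPH
from torch import nn

# x_train, y_train and val are assumed to be prepared as in the pycox
# cox-ph example notebook (float32 features; y = (durations, events)).

# Hyperparameters: values from the pycox cox-ph example notebook, used here as
# placeholders (assumed; the values used for this run were not posted).
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = 1      # CoxPH expects a single log-risk output
batch_norm = True
dropout = 0.1
output_bias = False   # a constant offset cancels out of the Cox partial likelihood

# Defining the network architecture
net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, activation=nn.SELU, output_bias=output_bias)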

# Setting up the optimizer
optimizer = tt.optim.SGD(lr=0.0002922, momentum=0.906, weight_decay=0.000399)
model_0a = CoxPH(net, optimizer)

# Training the model
epochs = 100
callbacks = [tt.callbacks.EarlyStopping()]
batch_size = 50
log = model_0a.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)
_ = log.plot()

# Predicting risk scores
def predict_risk(x):
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, dtype=torch.float32)
    with torch.no_grad():
        return model_0a.predict(x).cpu().numpy().astype(np.float32)

# SHAP evaluation
col = ['x0', 'x1', 'x2', 'x3','x4', 'x5', 'x6', 'x7', 'x8', 'x9']
explainer = shap.Explainer(predict_risk, x_train[:1000], feature_names=col)
shap_values = explainer(x_train[:1000])

# Plotting the SHAP values
shap.summary_plot(shap_values, x_train[:1000])

[Figure: SHAP summary plot for the linear dataset]
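
shap.summary_plot orders features by their mean absolute SHAP value, so the same ranking can be read off numerically from shap_values.values. The NumPy sketch below does this on a small made-up array standing in for that attribute (rank_features is a hypothetical helper, and the numbers are not real output). For the linear dataset, if the network has recovered h(x) = x_0 + 2*x_1, one would expect x1 first and x0 second, with the other features near zero.

```python
import numpy as np


def rank_features(values, feature_names):
    """Rank features by mean |SHAP value|, the ordering summary plots use.

    values: array of shape (n_samples, n_features), e.g. shap_values.values.
    """
    values = np.asarray(values, dtype=float)
    importance = np.abs(values).mean(axis=0)
    order = np.argsort(-importance, kind="stable")
    return [(feature_names[i], float(importance[i])) for i in order]


# Made-up SHAP values for three samples and three features (not real output).
fake_values = np.array([
    [0.4, -1.0, 0.01],
    [-0.2, 0.6, -0.02],
    [0.3, -0.8, 0.00],
])
for name, score in rank_features(fake_values, ["x0", "x1", "x2"]):
    print(f"{name}: {score:.3f}")
```

This is a quick way to check the overfitting mentioned above: noise features with large mean |SHAP| on the training set would be a warning sign.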

teituroli commented 1 year ago

@LuisPerezLombardia, wonderful that you found a solution!