VanishingRasengan opened this issue 2 years ago
Hi, so we haven't made any effort to make pycox compatible with SHAP. I can see that there is some preliminary support for PyTorch in SHAP, but I'm really not familiar with it. If you feel like it, you could try to make it work by following some of the PyTorch examples in SHAP and replicating them with pycox.
I guess full support would be nice if someone were to look into it.
Dear @havakv Sir,
First of all, I want to say thank you for maintaining this amazing Python package for survival analysis!
However, I am trying to apply SHAP (https://github.com/slundberg/shap), a tool for explainable AI, to pycox, but it does not seem to be compatible.
I followed the same steps as here (https://nbviewer.org/github/havakv/pycox/blob/master/examples/cox-ph.ipynb?ref=https://githubhelp.com) and added this at the end:
import shap
explainer = shap.Explainer(model, x_train)
shaps = explainer(x_test)
The Error message is:
The passed model is not callable and cannot be analyzed directly with the given masker! Model: <pycox.models.cox.CoxPH object at 0x7f88195c3d90>
Is pycox not compatible with SHAP, or did I use the "Explainer" method incorrectly?
Kind regards, Muralee
Hi, is there any chance you have found a solution?
All the best, T
I haven't tested this with pycox myself, but I had good results with scikit-survival and SHAP using the predict trick:
import shap
explainer = shap.Explainer(model.predict, X_train, feature_names=feature_names)
shap_values = explainer(X_test)
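The predict trick works because `shap.Explainer` needs a callable that maps a 2-D batch of shape (n_samples, n_features) to model outputs, while the pycox `CoxPH` object itself is not callable (hence the error message above). A minimal sketch of such an adapter, using a hypothetical stand-in class `ToyRiskModel` in place of a fitted pycox model so the snippet runs without pycox, torch or SHAP installed; `make_predict_fn` is likewise a hypothetical helper name:

```python
from typing import Callable, List, Sequence


class ToyRiskModel:
    """Hypothetical stand-in for a fitted model: it has .predict() but no __call__."""

    def __init__(self, weights: Sequence[float]):
        self.weights = list(weights)

    def predict(self, rows: Sequence[Sequence[float]]) -> List[List[float]]:
        # One single-element row per input, mimicking an (n, 1) output shape.
        return [[sum(w * v for w, v in zip(self.weights, row))] for row in rows]


def make_predict_fn(model) -> Callable[[Sequence[Sequence[float]]], List[float]]:
    """Wrap model.predict so it takes a batch of rows and returns one value per row."""

    def predict_fn(rows):
        # Cast inputs to float and flatten the (n, 1) output to length n.
        out = model.predict([[float(v) for v in row] for row in rows])
        return [row[0] for row in out]

    return predict_fn


model = ToyRiskModel([1.0, 2.0])
print(callable(model))  # False: passing the object itself is what fails
f = make_predict_fn(model)
print(callable(f), f([[1, 1], [0, 3]]))  # True [3.0, 6.0]
```

With a real model, the wrapper (or simply `model.predict`) is what gets passed to `shap.Explainer` together with a background array. Flattening to one value per sample avoids SHAP treating an (n, 1) output as a multi-output model, which may add an extra trailing dimension to the SHAP values.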
@VanishingRasengan @teituroli Hi, I am having a similar issue with SHAP. Do you have a solution for using SHAP with pycox?
Hi @hellorp1990, I did not spend much time trying to find a solution. Instead, I wrote my own loss function based on the Cox proportional hazards model. You can find the (very messy) code here. Feel free to ask any questions. Then I just use SHAP as:
import shap
explainer = shap.Explainer(model.predict, x_train)
shap_values = explainer(x_test)
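The linked code is not preserved here, but a loss "based on the Cox proportional hazards model" is usually the negative Cox partial log-likelihood: for risk scores h_i, it averages -(h_i - log sum_{j: t_j >= t_i} exp(h_j)) over the subjects with an observed event. A minimal pure-Python sketch under that assumption (O(n^2), ties handled Breslow-style by keeping all tied subjects in the risk set); it illustrates the formula and is neither the poster's implementation nor pycox's:

```python
import math
from typing import Sequence


def cox_ph_loss(log_h: Sequence[float], durations: Sequence[float],
                events: Sequence[int]) -> float:
    """Negative Cox partial log-likelihood, averaged over observed events."""
    n = len(log_h)
    if not (len(durations) == len(events) == n):
        raise ValueError("log_h, durations and events must have the same length")
    n_events = sum(1 for e in events if e)
    if n_events == 0:
        raise ValueError("at least one observed event is required")
    total = 0.0
    for i in range(n):
        if not events[i]:
            continue  # censored subjects only contribute through the risk sets
        # Risk set: everyone still under observation at time t_i (ties included).
        risk = [log_h[j] for j in range(n) if durations[j] >= durations[i]]
        m = max(risk)  # shift for a numerically stable log-sum-exp
        log_sum = m + math.log(sum(math.exp(r - m) for r in risk))
        total += log_h[i] - log_sum
    return -total / n_events


# Equal scores: the first event has 3 subjects at risk, the second has 2.
print(cox_ph_loss([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], [1, 1, 0]))  # (log 3 + log 2) / 2
```

Giving the early event a higher risk score lowers the loss, which is the behaviour a network trained on this objective is pushed towards.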
Hi, this seems to work for me on the linear synthetic dataset described in:
Katzman, J.L., Shaham, U., Cloninger, A. et al. (2018). DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Med Res Methodol, 18, 24.
Here, the log-risk function is h(x) = x_0 + 2*x_1.
I still need to fine-tune the hyperparameters; the model is overfitting the training dataset.
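For anyone wanting to reproduce a setup like this, a sketch of simulating survival data under a Cox model with this linear log-risk: covariates drawn uniformly from [-1, 1], event times from an exponential baseline hazard scaled by exp(h(x)), and uniform censoring times. The function name `simulate_linear_cox`, the baseline rate `lam0` and the censoring bound `max_censor` are assumptions for illustration, not the exact settings used in the DeepSurv paper:

```python
import math
import random
from typing import List, Tuple


def simulate_linear_cox(n: int, n_features: int = 10, lam0: float = 0.1,
                        max_censor: float = 20.0, seed: int = 0
                        ) -> Tuple[List[List[float]], List[float], List[int]]:
    """Simulate (x, duration, event) with hazard lam0 * exp(x_0 + 2*x_1) (hypothetical settings)."""
    rng = random.Random(seed)
    xs, durations, events = [], [], []
    for _ in range(n):
        x = [rng.uniform(-1.0, 1.0) for _ in range(n_features)]
        h = x[0] + 2.0 * x[1]  # true log-risk; x_2..x_9 are pure noise
        t_event = rng.expovariate(lam0 * math.exp(h))  # exponential with rate lam0 * exp(h)
        t_censor = rng.uniform(0.0, max_censor)
        xs.append(x)
        durations.append(min(t_event, t_censor))
        events.append(int(t_event <= t_censor))  # 1 = event observed, 0 = censored
    return xs, durations, events


x, durations, events = simulate_linear_cox(1000)
print(len(x), len(x[0]), sum(events))
```

The resulting lists would still need converting (for example to float32 NumPy arrays) into the `x_train` and `y_train = (durations, events)` form used in the pycox example notebooks.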
import torch
import torchtuples as tt
import numpy as np
import shap
from pycox.models import CoxPH
from torch import nn
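# Define the network; x_train, y_train, val and the hyperparameters (in_features, num_nodes,
# out_features, batch_norm, dropout, output_bias) are assumed to be defined beforehand
net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, activation=nn.SELU, output_bias=output_bias)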
# Setting up the optimizer
optimizer = tt.optim.SGD(lr=0.0002922, momentum=0.906, weight_decay=0.000399)
model_0a = CoxPH(net, optimizer)
# Training the model
epochs = 100
callbacks = [tt.callbacks.EarlyStopping()]
batch_size = 50
log = model_0a.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)
_ = log.plot()
# Predicting risk scores
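def predict_risk(x):
    # SHAP passes NumPy arrays; convert them to a float32 tensor for the network
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, dtype=torch.float32)
    with torch.no_grad():
        # predict() returns shape (n, 1); flatten so SHAP sees one output per sample
        return model_0a.predict(x).cpu().numpy().astype(np.float32).ravel()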
# SHAP evaluation
col = ['x0', 'x1', 'x2', 'x3','x4', 'x5', 'x6', 'x7', 'x8', 'x9']
explainer = shap.Explainer(predict_risk, x_train[:1000], feature_names=col)
shap_values = explainer(x_train[:1000])
# Plotting the SHAP values
shap.summary_plot(shap_values, x_train[:1000])
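Because the simulated log-risk is linear and the covariates are independent, the SHAP values a well-fitted model should roughly reproduce are known in closed form: with an interventional background, x_0 gets x_0 - mean(x_0), x_1 gets 2*(x_1 - mean(x_1)), and x_2, ..., x_9 get 0. An additive constant in the network output (which a Cox model cannot identify) does not change these values. The sketch below checks this by brute-forcing exact Shapley values over all feature subsets, using the average prediction over a background set as the value function; the helper name `exact_shapley` and the small random background are assumptions for illustration, and SHAP's own explainers compute or approximate the same quantity far more efficiently:

```python
import itertools
import math
import random
from typing import Callable, Dict, FrozenSet, List, Sequence


def exact_shapley(f: Callable[[Sequence[float]], float], x: Sequence[float],
                  background: Sequence[Sequence[float]]) -> List[float]:
    """Exact interventional Shapley values of f at x (cost grows as 2**len(x))."""
    p = len(x)
    cache: Dict[FrozenSet[int], float] = {}

    def value(subset: FrozenSet[int]) -> float:
        # Features in `subset` are taken from x, the rest from each background row.
        if subset not in cache:
            rows = [[x[k] if k in subset else b[k] for k in range(p)] for b in background]
            cache[subset] = sum(f(z) for z in rows) / len(rows)
        return cache[subset]

    phi = []
    for i in range(p):
        others = [k for k in range(p) if k != i]
        total = 0.0
        for size in range(p):
            # Shapley weight |S|! (p - |S| - 1)! / p!
            weight = math.factorial(size) * math.factorial(p - size - 1) / math.factorial(p)
            for s in itertools.combinations(others, size):
                s_set = frozenset(s)
                total += weight * (value(s_set | {i}) - value(s_set))
        phi.append(total)
    return phi


def h(z: Sequence[float]) -> float:
    return z[0] + 2.0 * z[1]  # linear log-risk; z[2:] are ignored


rng = random.Random(0)
background = [[rng.uniform(-1.0, 1.0) for _ in range(10)] for _ in range(20)]
x = [0.5, -0.25] + [0.3] * 8
phi = exact_shapley(h, x, background)
mean0 = sum(b[0] for b in background) / len(background)
mean1 = sum(b[1] for b in background) / len(background)
print([round(v, 3) for v in phi])
print(round(x[0] - mean0, 3), round(2 * (x[1] - mean1), 3))  # closed-form values for x_0, x_1
```

For the trained model, a comparable sanity check is whether the per-feature columns of `shap_values.values` track x_0 and 2*x_1 (minus their background means) while the remaining eight features stay near zero; large attributions on x_2..x_9 would be another sign of the overfitting mentioned above.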
@LuisPerezLombardia, wonderful that you found a solution!