sebp / scikit-survival

Survival analysis built on top of scikit-learn
GNU General Public License v3.0

viz of ensemble models #110

Closed · ad05bzag closed this issue 4 years ago

ad05bzag commented 4 years ago

Hi!

Would you have any advice on how to visualize the decision path / decision trees from the ensemble survival models (either random forest or gradient boosting)?

sebp commented 4 years ago

I uploaded a modification of the sklearn.tree.export module to this gist. It allows you to use the plot_tree function with an instance of SurvivalTree.

plot_tree(survival_tree, feature_names=feature_names, impurity=False, label="none")

ad05bzag commented 4 years ago

Cool, thank you so much! I ended up making it work using three methods: your suggested one, export_graphviz (graphviz library), and dtreeviz (explained.ai package).

Yours worked like so:

rsf = RandomSurvivalForest()
# (fit the model first; estimators_ only exists after calling rsf.fit)
plot_tree(rsf.estimators_[0], feature_names=X_train.iloc[:, 2:].columns.tolist(), impurity=False, label="none")

For the two latter methods, it came down to just subsetting the estimators, like so:

from graphviz import Source
from sklearn.tree import export_graphviz

gbs = GradientBoostingSurvivalAnalysis()
# (fit the model first; estimators_ only exists after calling gbs.fit)
feature_names = X_train.iloc[:, 2:].columns.tolist()
target_names = df['death'].unique()

graph = Source(export_graphviz(gbs.estimators_[0][0],
                               feature_names=feature_names,
                               class_names=target_names,
                               filled=True,
                               rounded=True))

and for dtreeviz:

from dtreeviz.trees import dtreeviz

viz = dtreeviz(gbs.estimators_[0][0], X_trainRF,
               np.array([x[1] for x in y_trainRF]),
               target_name='time to event',
               feature_names=X_train.iloc[:, 2:].columns.tolist())



Q: What I have been thinking about is that the innate limitation of all three methods is that we only ever inspect a single tree, which seems lacking. For both gradient boosting and random forests I have >100 estimators. How can one get a more cohesive, inclusive picture?
sebp commented 4 years ago

I'm not aware of a method that visualizes ensembles of trees in an easy-to-understand way.
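One way to summarize a whole forest instead of a single tree is to aggregate a statistic across all estimators, e.g. how often each feature is chosen as a split variable. Below is a minimal sketch of this idea using a plain sklearn RandomForestRegressor on synthetic data as a stand-in; the same `estimators_` / `tree_` traversal should apply to RandomSurvivalForest, since its trees follow the scikit-learn tree layout. The data and model here are made up for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X[:, 0] * 2.0 + rng.normal(scale=0.1, size=200)  # feature 0 dominates

forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

# Count how often each feature is used as a split variable across ALL trees,
# instead of inspecting any single tree in isolation.
n_features = X.shape[1]
split_counts = np.zeros(n_features, dtype=int)
for tree in forest.estimators_:
    used = tree.tree_.feature          # negative values mark leaf nodes
    for f in used[used >= 0]:
        split_counts[f] += 1

for i, c in enumerate(split_counts):
    print(f"feature {i}: used in {c} splits")
```

Permutation-based feature importance (sklearn.inspection.permutation_importance) is another ensemble-level summary that works with any estimator following the scikit-learn API.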

julienbeisel commented 3 years ago

@sebp Thanks for this Gist, it's very useful!

However, I'm having trouble showing impurity: every value is "inf". Do you know if it is possible to display the log-rank value for each node? Thank you.

sebp commented 3 years ago

@julienbeisel The infinity values are due to the underlying tree-splitting implementation of scikit-learn. It defines impurity on a per-node level, but the log-rank split criterion is only defined per split, not per node. Unfortunately, scikit-learn doesn't seem to expose the impurity_improvement part, which is where the log-rank statistic is computed, to Python.

julienbeisel commented 3 years ago

Thanks for your answer @sebp.

I ended up computing the log-rank value per split by modifying your Gist using the compare_survival function. It allows me to extract sub-profiles from the tree.
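The key mechanic for computing a statistic per split, as described above, is recovering which training samples fall into each child node; scikit-learn exposes this via decision_path. A sketch on a plain DecisionTreeRegressor with synthetic data follows (a stand-in, since SurvivalTree uses the same underlying tree_ layout); each node's left/right membership masks are what one would hand to a log-rank test such as sksurv.compare.compare_survival:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = (X[:, 0] > 0).astype(float) + rng.normal(scale=0.05, size=100)

tree = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

# decision_path: sparse (n_samples, n_nodes) indicator of which nodes
# each sample passes through.
node_indicator = tree.decision_path(X)
left = tree.tree_.children_left
right = tree.tree_.children_right

for node in range(tree.tree_.node_count):
    if left[node] == -1:     # leaf: no split to evaluate
        continue
    in_left = node_indicator[:, left[node]].toarray().ravel().astype(bool)
    in_right = node_indicator[:, right[node]].toarray().ravel().astype(bool)
    # in_left / in_right are the two groups one would feed to a
    # per-split test statistic (e.g. a log-rank test for survival data).
    print(f"node {node}: {in_left.sum()} left, {in_right.sum()} right")
```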

Machariajane commented 2 years ago


On graphviz, the solution doesn't seem to work on my end. If I use gbs.estimators_[0][0], the error is TypeError: 'SurvivalTree' object does not support indexing, and with gbs.estimators_[0] I get AttributeError: 'SurvivalTree' object has no attribute 'criterion'.

sebp commented 1 year ago

I haven't updated the code in a long time, so it's likely that it won't work with more recent versions of scikit-learn/scikit-survival anymore.

gorj-tessella commented 1 year ago

@julienbeisel how did you modify the gist?

gorj-tessella commented 1 year ago

I'd like to use this for inspecting the basic SurvivalTree model, but it's not reporting the value for all nodes (which I presume would be a risk/hazard?).