Closed ad05bzag closed 4 years ago
I uploaded a modification of the sklearn.tree.export module to this gist. It allows you to use the plot_tree function with an instance of SurvivalTree.
plot_tree(survival_tree, feature_names=feature_names, impurity=False, label="none")
Cool, thank you so much! I ended up making it work using three methods: your suggested one, export_graphviz (graphviz library), and dtreeviz (explained.ai package).
Yours worked like so:
rsf = RandomSurvivalForest()  # must be fit before estimators_ exists
plot_tree(rsf.estimators_[0], feature_names=X_train.iloc[:, 2:].columns.tolist(), impurity=False, label="none")
For the two latter methods, it came down to just subsetting the estimators, like so:
from graphviz import Source
from sklearn.tree import export_graphviz
gbs = GradientBoostingSurvivalAnalysis()  # must be fit before estimators_ exists
feature_names=X_train.iloc[:,2:].columns.tolist()
target_names=df['death'].unique()
graph=Source(export_graphviz(gbs.estimators_[0][0],
feature_names=feature_names,
class_names=target_names,
filled=True,
rounded=True))
from dtreeviz.trees import *
viz = dtreeviz(gbs.estimators_[0][0], X_trainRF, np.array([x[1] for x in y_trainRF]), target_name='time to event', feature_names=X_train.iloc[:,2:].columns.tolist())
Q: What I have been thinking about is that the innate limitation of all three methods is that we only inspect a single tree. That seems lacking: for both gradient boosting and random forests I have >100 estimators. How can one give a more cohesive, inclusive picture?
I'm not aware of a method that visualizes ensembles of trees in an easy-to-understand way.
@sebp Thanks for this Gist, it's very useful!
However, I'm having trouble showing impurity: every value is "inf". Do you know if it is possible to display the log-rank value for each node? Thank you.
@julienbeisel The infinity values are due to the underlying tree-splitting implementation of scikit-learn. It defines impurity at the node level, but the log-rank split criterion is only defined per split, not per node. Unfortunately, scikit-learn doesn't seem to expose the impurity_improvement part to Python, which is where the log-rank statistic is computed.
Thanks for your answer @sebp.
I ended up computing the log-rank value per split by modifying your Gist using the compare_survival function. It allows me to extract sub-profiles from the tree.
On graphviz, the solution doesn't seem to work on my end: if I use gbs.estimators_[0][0], the error is TypeError: 'SurvivalTree' object does not support indexing, and with gbs.estimators_[0] I get AttributeError: 'SurvivalTree' object has no attribute 'criterion'.
I haven't updated the code in a long time, so it's likely that it won't work with more recent versions of scikit-learn/scikit-survival anymore.
@julienbeisel how did you modify the gist?
I'd like to use this for inspecting the basic SurvivalTree model, but it's not reporting the value for all nodes (which I presume would be a risk/hazard?)
Hi!
Would you have any advice on how to visualize decision paths / decision trees from the ensemble survival model methods (either RF or Gradient Boosting)?