loft-br / xgboost-survival-embeddings

Improving XGBoost survival analysis with embeddings and debiased estimators
https://loft-br.github.io/xgboost-survival-embeddings/
Apache License 2.0

AttributeError: 'XGBSEBootstrapEstimator' object has no attribute 'get_neighbors' #67

Open gustavjandrup opened 1 year ago

gustavjandrup commented 1 year ago

Code sample


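```python
from xgbse import XGBSEBootstrapEstimator, XGBSEStackedWeibull

# DEFAULT_PARAMS, TIME_BINS, X_train, y_train and above_1 are defined elsewhere (not shown in the report)
xgbse_model = XGBSEStackedWeibull(xgb_params=DEFAULT_PARAMS)
bootstrap_estimator = XGBSEBootstrapEstimator(xgbse_model)
bootstrap_estimator.fit(
    X_train, y_train, time_bins=TIME_BINS, persist_train=True, index_id=X_train.index
)

# Raises AttributeError: 'XGBSEBootstrapEstimator' object has no attribute 'get_neighbors'
neighbors = bootstrap_estimator.get_neighbors(
    query_data=above_1,
    index_data=X_train,
    n_neighbors=10,
)

print(neighbors)
```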

Problem description

I am trying to get the neighbors of an observation, to use as samples for local explainability and to compute SHAP values with a bootstrap estimator. However, I get this error message: AttributeError: 'XGBSEBootstrapEstimator' object has no attribute 'get_neighbors'.

Expected behavior

Since your documentation states that BaseEstimator is the base class for all estimators in xgbse, I would expect .get_neighbors to work with every estimator in xgbse.

Possible solutions

From your documentation, it looks like .get_neighbors was never tested with XGBSEBootstrapEstimator. That would be a good place to start.

Also, if I have misunderstood something about your XGBSEBootstrapEstimator, please let me know :)

gdmarmerola-loft commented 1 year ago

Hey @gustavjandrup! Currently there is no support for .get_neighbors() in XGBSEBootstrapEstimator. To add it, we would have to devise a strategy for combining the neighbors of all underlying estimators.

My recommendation would be to access each estimator separately via .estimators_ and devise a score function to combine all of their neighbors. One idea is to count how many estimators list a given sample as a neighbor of the query, then rank the candidates so that the closest neighbor is the one that shows up as a neighbor in the most estimators.
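A minimal sketch of that vote-counting idea in plain Python. It assumes the neighbors from each underlying estimator have already been collected into a dict mapping a query id to its neighbor ids, closest first; the function name `combine_neighbor_votes` and this input format are hypothetical, not part of xgbse. Candidates are ranked by how many estimators listed them, and breaking ties by mean position in the per-estimator lists is an extra assumption on top of the suggestion.

```python
from collections import Counter, defaultdict
from typing import Dict, Hashable, List, Mapping, Sequence


def combine_neighbor_votes(
    per_estimator: Sequence[Mapping[Hashable, Sequence[Hashable]]],
    n_neighbors: int,
) -> Dict[Hashable, List[Hashable]]:
    """Combine per-estimator neighbor lists into one ranking per query (hypothetical helper).

    per_estimator: one mapping per bootstrap estimator,
        {query_id: [neighbor ids, closest first]} (an assumed format, not an xgbse type).
    """
    if n_neighbors < 1:
        raise ValueError("n_neighbors must be at least 1")

    # query -> candidate -> number of estimators that listed the candidate
    votes: Dict[Hashable, Counter] = defaultdict(Counter)
    # query -> candidate -> sum of the candidate's positions across those lists
    position_sums: Dict[Hashable, Counter] = defaultdict(Counter)

    for table in per_estimator:
        for query_id, neighbors in table.items():
            for position, neighbor_id in enumerate(neighbors):
                votes[query_id][neighbor_id] += 1
                position_sums[query_id][neighbor_id] += position

    combined: Dict[Hashable, List[Hashable]] = {}
    for query_id, counts in votes.items():
        sums = position_sums[query_id]
        # Most votes first; ties broken by mean position (an assumption), then by
        # first appearance, since sorted() is stable and Counter keeps insertion order.
        ranked = sorted(counts, key=lambda c: (-counts[c], sums[c] / counts[c]))
        combined[query_id] = ranked[:n_neighbors]
    return combined
```

To feed it from a fitted model, one could loop over `bootstrap_estimator.estimators_` and call `get_neighbors` on each model with the same arguments as in the code sample. This assumes the base estimator supports `get_neighbors` and returns a pandas DataFrame with one row per query and one neighbor id per column, which is worth checking against the installed xgbse version; under that assumption each table converts with `{q: list(row) for q, row in table.iterrows()}`.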

Thanks!