catboost / catboost

A fast, scalable, high performance Gradient Boosting on Decision Trees library, used for ranking, classification, regression and other machine learning tasks for Python, R, Java, C++. Supports computation on CPU and GPU.
https://catboost.ai
Apache License 2.0

Get splits for nodes in trees generated in CatBoost as Pandas DataFrame #2501

Open ashish-wai opened 12 months ago

ashish-wai commented 12 months ago

Problem: I'm working with CatBoost and I need to extract information about the trees generated by a trained CatBoost model in the form of a dataframe, similar to how it can be done with XGBoost using model.get_booster().trees_to_dataframe(). This information is helpful for further analysis and visualization.

Is there an equivalent method or approach in CatBoost to achieve this? If so, could you please provide a code example or guidance on how to obtain a dataframe of the trees from a CatBoost model?

Possibly the DataFrame (or any other object carrying the required information) should contain these columns: Tree, Node, ID, Feature, Split, Yes, No, Missing, Gain, Cover.

from xgboost import XGBRegressor

model = XGBRegressor(n_estimators=20, n_jobs=10,
                     objective='reg:squarederror')

model.fit(X[0]['train'], Y['train'])
trees = model.get_booster().trees_to_dataframe()

I want to get a similar DataFrame with all details in CatBoost:

from catboost import CatBoostRegressor

model = CatBoostRegressor(iterations=10,
                          thread_count=10)

model.fit(X['train'], Y['train'], cat_features=range(n_num_features, n_features))

catboost version: 0.24.4
Operating System: Ubuntu 18.04
CPU: -
GPU: -

inti4digbi commented 6 months ago

Hi, I have not found how to do this.

I found this potentially useful, but still cannot get it to work: https://medium.com/@joachimiak.krzysztof/extracting-trees-from-gbm-models-as-data-frames-ce37f4c08ba6

Many thanks in advance.

Evgueni-Petrov-aka-espetrov commented 5 months ago

hi @inti4digbi

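import json

from catboost import CatBoostRegressor
from sklearn.datasets import make_regression

# Fit a model on synthetic regression data
X, y = make_regression()
cb_reg = CatBoostRegressor(iterations=1).fit(X, y, verbose=False)

# Save the model in JSON format
cb_reg.save_model('cb.dump', format='json')

# Load the dump and print the list of trees
with open('cb.dump') as f:
    print(json.load(f)["oblivious_trees"])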

This prints a JSON array of trees like the one shown below (here it contains one tree).

An oblivious tree of depth d has 2**d leaves.
["leaf_values"][i] is the prediction in leaf i, and ["leaf_weights"][i] is the number of samples in leaf i.
["splits"] is a list of splits of length d.
For each entry of ["splits"], ["border"] is the split border, ["float_feature_index"] is the index of the split feature, and ["split_index"] is the split level in the tree.

[
{
  "leaf_values": [
    -7.475223660469055,
    -8.40327274799347,
    -10.556573987007141,
    -25.543860054016115,
    0,
    0,
    0,
    -2.6524535417556763,
    -27.458484252293903,
    22.67806537946065,
    -20.60556435585022,
    30.966007232666016,
    0,
    11.388235926628113,
    5.983312964439392,
    2.396146217981974,
    0,
    0,
    -38.0130708694458,
    28.552173852920532,
    0,
    0,
    -11.534627079963684,
    0,
    -44.91376667022705,
    -16.05863755941391,
    -27.97081768512726,
    -21.681302626927693,
    3.106310248374939,
    27.364540672302248,
    0,
    70.00350038210551,
    0,
    11.622217059135437,
    4.541542172431946,
    -7.182486152648925,
    0,
    7.111119627952576,
    -35.49035847187042,
    27.451266686121624,
    0,
    23.92159298488072,
    -64.15807361602783,
    22.18560619354248,
    0,
    40.62969183921814,
    -7.157690167427063,
    33.26015079021454,
    0,
    -17.732578468322753,
    -35.640793037414554,
    3.948658068974813,
    -21.542970776557922,
    -18.490345191955566,
    0,
    0,
    0,
    3.8491021394729614,
    -19.747035026550293,
    32.156920512517296,
    0,
    0,
    -12.512924313545227,
    16.398515892028808
  ],
  "leaf_weights": [
    1,
    1,
    1,
    2,
    0,
    0,
    0,
    1,
    3,
    3,
    3,
    2,
    0,
    1,
    1,
    3,
    0,
    0,
    2,
    3,
    0,
    0,
    1,
    0,
    2,
    5,
    1,
    3,
    1,
    2,
    0,
    3,
    0,
    1,
    1,
    2,
    0,
    1,
    1,
    3,
    0,
    4,
    2,
    12,
    0,
    3,
    1,
    1,
    0,
    2,
    2,
    3,
    1,
    2,
    0,
    0,
    0,
    1,
    6,
    3,
    0,
    0,
    1,
    2
  ],
  "splits": [
    {
      "border": -0.5437842607498169,
      "float_feature_index": 43,
      "split_index": 3,
      "split_type": "FloatFeature"
    },
    {
      "border": -0.5657655000686646,
      "float_feature_index": 39,
      "split_index": 2,
      "split_type": "FloatFeature"
    },
    {
      "border": 0.6844918727874756,
      "float_feature_index": 81,
      "split_index": 4,
      "split_type": "FloatFeature"
    },
    {
      "border": -0.46684354543685913,
      "float_feature_index": 26,
      "split_index": 0,
      "split_type": "FloatFeature"
    },
    {
      "border": 0.17201551795005798,
      "float_feature_index": 29,
      "split_index": 1,
      "split_type": "FloatFeature"
    },
    {
      "border": -0.11797700822353363,
      "float_feature_index": 89,
      "split_index": 5,
      "split_type": "FloatFeature"
    }
  ]
}
]
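To get closer to the DataFrame asked for in the original question, the JSON structure above can be flattened into one row per split and one row per leaf. The following is only a minimal sketch, not a CatBoost API: `trees_to_rows` is a hypothetical helper, the column names are made up for illustration, and the tiny `sample` tree is hand-written in the same shape as the dump above. It assumes a single-output model with oblivious trees, so that each tree has exactly 2**d leaf values for d splits. The `split_index` field is carried through unchanged, and the position of each split in the `"splits"` list is recorded separately.

```python
def trees_to_rows(oblivious_trees):
    """Flatten trees shaped like the JSON dump above into plain row dicts.

    Hypothetical helper: returns (split_rows, leaf_rows), two lists of dicts.
    """
    split_rows, leaf_rows = [], []
    for tree_id, tree in enumerate(oblivious_trees):
        splits = tree["splits"]
        values = tree["leaf_values"]
        weights = tree["leaf_weights"]

        # Assumption: single-output model, so a tree of depth d has 2**d leaves.
        if len(values) != 2 ** len(splits) or len(weights) != len(values):
            raise ValueError(f"tree {tree_id}: unexpected number of leaves")

        for position, split in enumerate(splits):
            split_rows.append({
                "tree": tree_id,
                "position": position,  # position in the "splits" list
                "split_type": split.get("split_type"),
                # .get() because non-float splits may lack these keys
                "feature": split.get("float_feature_index"),
                "border": split.get("border"),
                "split_index": split.get("split_index"),
            })

        for leaf_id, (value, weight) in enumerate(zip(values, weights)):
            leaf_rows.append({
                "tree": tree_id,
                "leaf": leaf_id,
                "value": value,
                "weight": weight,
            })
    return split_rows, leaf_rows


# Hand-written depth-2 tree in the same shape as the dump above.
sample = [{
    "leaf_values": [-7.5, -8.4, 0, 22.7],
    "leaf_weights": [1, 1, 0, 3],
    "splits": [
        {"border": -0.54, "float_feature_index": 43,
         "split_index": 1, "split_type": "FloatFeature"},
        {"border": 0.17, "float_feature_index": 29,
         "split_index": 0, "split_type": "FloatFeature"},
    ],
}]

split_rows, leaf_rows = trees_to_rows(sample)
```

With a real model, the input would be the `"oblivious_trees"` list loaded from `cb.dump` as in the snippet above, and each list of rows can be passed to `pandas.DataFrame(...)` to obtain a table. Gain- and cover-style columns like those in XGBoost's output are not derived here, since the dump shown in this thread does not include them per split.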