mljar / supertree

Visualize decision trees in Python
https://mljar.com/supertree
GNU Affero General Public License v3.0

Error with categorical features #14

Closed Thornam closed 1 month ago

Thornam commented 1 month ago

When using categorical features that aren't binary, I get an error in show_tree().


From what I can see while debugging, the error occurs because splits on categorical features have more than one threshold, so the threshold is not a float.
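A minimal sketch of how this could be checked, assuming the tree is available as a nested dict shaped like LightGBM's `Booster.dump_model()` output (keys `split_feature`, `threshold`, `left_child`, `right_child`), and assuming a categorical split stores its threshold as a `||`-separated string such as `"0||2"`. The dict below is hand-written for illustration and is not output from an actual model; `find_non_float_thresholds` is a hypothetical helper, not part of supertree or LightGBM.

```python
def find_non_float_thresholds(node, found=None):
    """Collect (split_feature, threshold) pairs whose threshold is not numeric."""
    if found is None:
        found = []
    # Leaf nodes carry no "threshold" key, so the recursion stops there.
    if "threshold" not in node:
        return found
    threshold = node["threshold"]
    if not isinstance(threshold, (int, float)):
        found.append((node["split_feature"], threshold))
    find_non_float_thresholds(node["left_child"], found)
    find_non_float_thresholds(node["right_child"], found)
    return found


# Hand-written example in the assumed dump format (illustrative values only).
example_tree = {
    "split_feature": 1,
    "decision_type": "==",
    "threshold": "0||2",  # assumed encoding of a categorical split
    "left_child": {"leaf_value": 150.0},
    "right_child": {
        "split_feature": 0,
        "decision_type": "<=",
        "threshold": 0.005,  # ordinary numeric split
        "left_child": {"leaf_value": 120.0},
        "right_child": {"leaf_value": 180.0},
    },
}

non_float = find_non_float_thresholds(example_tree)
print(non_float)
```

Any entry in the returned list marks a split that code expecting a float threshold would fail on.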

pplonski commented 1 month ago

Hi @Thornam,

Thank you for reporting the issue. What algorithm are you using for the trees? Could you please provide code to reproduce the issue so we can fix the problem? Thank you!

Thornam commented 1 month ago

Hi, thanks for the response.

I'm using an LGBMRegressor and have modified your diabetes example to reproduce the error:

import numpy as np
import pandas as pd
import lightgbm as lgbm
from sklearn.datasets import load_diabetes
from supertree import SuperTree  # <- import supertree :)

# Load the diabetes dataset
diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target

# Create categorical feature
X_df = pd.DataFrame(X)
X_df['Cat'] = np.where(X_df[8] < -0.03, 1, np.where(X_df[8] > 0.03, 2, 3))
X_df['Cat'] = X_df['Cat'].astype('category')
X_df = X_df[[1, 'Cat']]

# Train model
model = lgbm.LGBMRegressor(verbose=-1)
model.fit(X_df, y)

# Initialize supertree
super_tree = SuperTree(model, X_df, y)
# show tree with index 2 in your notebook
super_tree.show_tree(2)

Marchlak commented 1 month ago

Hi @Thornam, thanks for reporting this. For now, I've moved the rounding of numbers to JavaScript, and that solved the problem in your example. I'm not familiar with decision trees that use categorical data, but I believe that for categorical splits the threshold is simply an integer category code (1, 2, 3, etc.) rather than a floating-point number. At least that's how it looks with LightGBM's scikit-learn interface, but please correct me if I'm wrong. Please check whether your example now works and whether the visualization meets your expectations. If you run into any more issues in Supertree, feel free to report them, and I'll be happy to fix them.
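For reference, here is a rough sketch of threshold formatting that does not assume a float. This is not the actual supertree change (which, as noted, moved rounding to JavaScript); `format_threshold` is a hypothetical helper, and the `||`-separated string form for categorical thresholds is an assumption about how LightGBM dumps such splits.

```python
def format_threshold(threshold, digits=3):
    """Return a display label for a split threshold.

    Numeric thresholds are rounded; string thresholds are treated as a
    "||"-separated set of integer category codes (an assumed encoding).
    """
    if isinstance(threshold, str):
        codes = [int(code) for code in threshold.split("||")]
        return "in {" + ", ".join(str(code) for code in codes) + "}"
    return str(round(float(threshold), digits))


numeric_label = format_threshold(0.0051234)
categorical_label = format_threshold("0||2")
print(numeric_label, categorical_label)
```

Handling the string case before rounding avoids calling `round()` on a value that cannot be converted to a float.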

Thornam commented 1 month ago

Thanks! It now works for categorical data with LGBM, both in the example and in my original problem.

pplonski commented 1 month ago

Good job @Marchlak 🥇