Closed: Thornam closed this issue 1 month ago
Hi @Thornam,
Thank you for reporting the issue. Which algorithm are you using for the trees? Could you please provide code that reproduces the issue so we can fix it? Thank you!
Hi, thanks for the response.
I'm using an LGBMRegressor, and I have modified your diabetes example to reproduce the error:
import numpy as np
import pandas as pd
import lightgbm as lgbm
from sklearn.datasets import load_diabetes
from supertree import SuperTree # <- import supertree :)
# Load the diabetes dataset
diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target
# Create categorical feature
X_df = pd.DataFrame(X)
X_df['Cat'] = np.where(X_df[8] < -0.03, 1, np.where(X_df[8] > 0.03, 2, 3))
X_df['Cat'] = X_df['Cat'].astype('category')
X_df = X_df[[1, 'Cat']]
# Train model
model = lgbm.LGBMRegressor(verbose=-1)
model.fit(X_df, y)
# Initialize supertree
super_tree = SuperTree(model, X_df, y)
# show tree with index 2 in your notebook
super_tree.show_tree(2)
Hi @Thornam, thanks for opening this issue. For now, I've moved the number rounding to JavaScript, which fixes the problem in your example. I'm not familiar with decision trees that split on categorical data, but I believe the threshold, normally a floating-point number, simply becomes an integer category code (1, 2, 3, etc.). At least LightGBM's scikit-learn interface seems to behave that way, but please correct me if I'm wrong. Please check whether your example now works and whether the visualization meets your expectations. If you run into any more issues in Supertree, feel free to report them, and I'll be happy to fix them.
Thanks! It now works for categorical data with LGBM, both in the example and in my original problem.
Good job @Marchlak :1st_place_medal:
When using categorical features that aren't binary, I get an error in show_tree().
As far as I can tell from debugging, the error occurs because a split on a categorical feature can have more than one threshold value, so the threshold is not a single float.
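For anyone hitting something similar, here is a minimal sketch of how a visualizer could handle both kinds of threshold. It assumes the categorical threshold arrives as a string of category codes joined by `||` (the form LightGBM's model dump uses, as far as I know) and a numerical threshold arrives as a plain number. `format_threshold` is a hypothetical helper for illustration, not part of Supertree or LightGBM.

```python
from typing import Union


def format_threshold(threshold: Union[int, float, str], decimals: int = 3) -> str:
    """Build a display label for a split threshold (hypothetical helper).

    Assumption: categorical splits are encoded as a string such as "1||2",
    while numerical splits are plain numbers.
    """
    if isinstance(threshold, str):
        # Categorical split: one or more category codes separated by "||".
        categories = [code for code in threshold.split("||") if code]
        return "in {" + ", ".join(categories) + "}"
    # Numerical split: safe to convert to float and round for display.
    return str(round(float(threshold), decimals))


print(format_threshold(0.0123456))  # numerical threshold
print(format_threshold("1||2"))     # categorical threshold with two categories
```

The point is only that rounding should be applied after checking the threshold's type, so a multi-category split never gets passed to float().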