aerdem4 / lofo-importance

Leave One Feature Out Importance
MIT License
810 stars 83 forks

Support multiclass classification? #40

Closed: ybdesire closed this issue 2 years ago

ybdesire commented 2 years ago

The code below works as expected and returns a populated importance_df.

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer, load_iris
from sklearn.model_selection import KFold
from lofo import LOFOImportance, Dataset, plot_importance

data = load_breast_cancer(as_frame=True)  # load as dataframe
df = data.data
df['target'] = data.target.values

# model
model = RandomForestClassifier()
# dataset
dataset = Dataset(df=df, target="target", features=[col for col in df.columns if col != 'target'])
# get feature importance
cv = KFold(n_splits=5, shuffle=True, random_state=666)
lofo_imp = LOFOImportance(dataset, cv=cv, scoring="f1", model=model)
importance_df = lofo_imp.get_importance()
print(importance_df)

But if we replace load_breast_cancer with load_iris, every value in importance_df is NaN.

Does lofo-importance only support binary classification?
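A likely cause can be reproduced with scikit-learn alone. The sketch below assumes that scoring="f1" maps to f1_score with its default average="binary", which scikit-learn rejects for targets with more than two classes; presumably that failure is what surfaces as NaN during cross-validation. The labels are made-up toy values, not the iris data.

```python
from sklearn.metrics import f1_score

# Toy three-class labels, made up for illustration (not taken from iris).
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 0]

# The default average="binary" is only defined for two-class targets,
# so scikit-learn raises ValueError when it sees three classes.
try:
    f1_score(y_true, y_pred)
    binary_failed = False
except ValueError:
    binary_failed = True

# An explicit multiclass average works on the same labels.
macro_f1 = f1_score(y_true, y_pred, average="macro")

print("binary averaging failed:", binary_failed)
print("macro F1:", round(macro_f1, 4))
```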

ybdesire commented 2 years ago

With FLOFO, the multiclass classification result is correct.

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from lofo import FLOFOImportance

# step-01: prepare data
data = load_iris(as_frame=True)  # load as dataframe
x_data = data.data.to_numpy()
y_data = data.target.values
df = data.data
df['target'] = data.target.values
# repeat every row 10 times, since FLOFO needs more than 1000 rows
# (np.repeat replaces pd.np.repeat, which was removed in pandas 2.0)
df = pd.DataFrame(np.repeat(df.values, 10, axis=0), columns=df.columns)
# step-02: train model
model = RandomForestClassifier()
model.fit(x_data, y_data)
# step-03: fast-lofo
lofo_imp = FLOFOImportance(validation_df=df, target="target", features=[col for col in df.columns if col != 'target'], scoring="f1_macro", trained_model=model)
importance_df = lofo_imp.get_importance()
print(importance_df)

ybdesire commented 2 years ago

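Changing scoring="f1" to scoring="f1_macro" fixed the issue. scoring="f1" uses binary averaging, which is not defined for more than two classes, so a multiclass F1 score should be calculated with an averaged variant such as f1_macro or f1_micro.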
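To illustrate how the two averaged variants differ, here is a minimal pure-Python sketch. The helper names (per_class_counts, f1_from_counts, macro_f1, micro_f1) and the toy labels are hypothetical and are not part of lofo-importance or scikit-learn. Macro averaging computes one F1 per class and takes the unweighted mean; micro averaging pools the counts across classes first.

```python
from collections import Counter

def per_class_counts(y_true, y_pred):
    """Count true positives, false positives and false negatives per class."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but it was wrong
            fn[t] += 1  # true class t was missed
    return tp, fp, fn

def f1_from_counts(tp, fp, fn):
    """F1 = 2*TP / (2*TP + FP + FN); returns 0.0 when the denominator is zero."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    tp, fp, fn = per_class_counts(y_true, y_pred)
    labels = sorted(set(y_true) | set(y_pred))
    return sum(f1_from_counts(tp[c], fp[c], fn[c]) for c in labels) / len(labels)

def micro_f1(y_true, y_pred):
    """F1 computed from counts pooled over all classes."""
    tp, fp, fn = per_class_counts(y_true, y_pred)
    return f1_from_counts(sum(tp.values()), sum(fp.values()), sum(fn.values()))

# Toy three-class labels, made up for illustration.
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 0]

macro = macro_f1(y_true, y_pred)
micro = micro_f1(y_true, y_pred)
print("macro F1:", round(macro, 4))
print("micro F1:", round(micro, 4))
```

For single-label multiclass data, micro F1 equals plain accuracy, so macro F1 is usually the more informative choice when the classes are imbalanced.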