dmlc / treelite

Universal model exchange and serialization format for decision tree forests
https://treelite.readthedocs.io/en/latest/
Apache License 2.0

Multi-class, multi-output RandomForestClassifier in scikit-learn produces an error #545

Closed: hcho3 closed this issue 4 months ago

hcho3 commented 5 months ago

Reproducer:

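import numpy as np
import treelite
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle

# Three-class labels for the first output
X, y1 = make_classification(
    n_samples=10, n_features=100, n_informative=30, n_classes=3, random_state=1
)
# Two shuffled copies of y1 act as the second and third outputs
y2 = shuffle(y1, random_state=1)
y3 = shuffle(y1, random_state=2)
Y = np.vstack((y1, y2, y3)).T  # target matrix of shape (n_samples, n_outputs)
print(X.shape, Y.shape)  # (10, 100) (10, 3)
# Problem dimensions (not used further in this snippet)
n_samples, n_features = X.shape
n_outputs = Y.shape[1]
n_classes = 3
# Multi-output classifier: 3 outputs with 3 classes each
forest = RandomForestClassifier(random_state=1)
forest.fit(X, Y)

# Raises the AssertionError shown below
model = treelite.sklearn.import_model(forest)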
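To see which shapes end up being compared, the fitted forest can be inspected directly. This is a minimal diagnostic sketch that rebuilds the same forest and reads only public scikit-learn attributes (`n_outputs_`, `n_classes_`, `estimators_`, `tree_.node_count`, `tree_.value`); it does not call treelite.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle

# Rebuild the multi-output forest from the reproducer
X, y1 = make_classification(
    n_samples=10, n_features=100, n_informative=30, n_classes=3, random_state=1
)
Y = np.vstack((y1, shuffle(y1, random_state=1), shuffle(y1, random_state=2))).T
forest = RandomForestClassifier(random_state=1).fit(X, Y)

# Forest-level metadata: one class count per output (a list when n_outputs_ > 1)
print("n_outputs_:", forest.n_outputs_)
print("n_classes_:", forest.n_classes_)

# Per-tree storage: a single 3-D array of shape (node_count, n_outputs, max_n_classes)
tree = forest.estimators_[0]
print("node_count:", tree.tree_.node_count)
print("value shape:", tree.tree_.value.shape)
```

For a multi-output classifier, `n_classes_` is a per-output list, while each tree keeps its leaf values in one 3-D array whose last axis is padded to the largest class count.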

Error:

  File "treelite/sklearn/importer.py", line 46, in add
    array.shape == expected_shape
AssertionError: Expected shape: (13, 3, [3, 3, 3]), Got shape (13, 3, 3)
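The message suggests that the expected shape was built from the forest's per-output `n_classes_` list ([3, 3, 3]) rather than a single class count. Since scikit-learn sizes `tree_.value` as (node_count, n_outputs, max_n_classes), one way to build the expected shape is to reduce that list with `max()`. The helper `expected_value_shape` below is hypothetical, not treelite's actual code or fix; it only sketches the normalization with NumPy.

```python
import numpy as np


def expected_value_shape(node_count, n_outputs, n_classes):
    """Hypothetical helper: expected shape of a classifier's tree_.value.

    n_classes may be a scalar (single output) or a per-output sequence
    (multi-output); scikit-learn pads every output to the largest count.
    """
    if np.isscalar(n_classes):
        max_n_classes = int(n_classes)
    else:
        max_n_classes = int(max(n_classes))
    return (node_count, n_outputs, max_n_classes)


# Dimensions taken from the error message above
value = np.zeros((13, 3, 3))
assert value.shape == expected_value_shape(13, 3, [3, 3, 3])
```

Reducing with `max()` also covers outputs with different class counts, for example [2, 4] maps to a last axis of length 4.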