Closed by hcho3 4 months ago.
Reproducer:
```python
import numpy as np
import treelite
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle

X, y1 = make_classification(
    n_samples=10, n_features=100, n_informative=30, n_classes=3, random_state=1
)
y2 = shuffle(y1, random_state=1)
y3 = shuffle(y1, random_state=2)
Y = np.vstack((y1, y2, y3)).T
print(X.shape, Y.shape)
n_samples, n_features = X.shape
n_outputs = Y.shape[1]
n_classes = 3
forest = RandomForestClassifier(random_state=1)
forest.fit(X, Y)
model = treelite.sklearn.import_model(forest)
```
Error:
```
  File "treelite/sklearn/importer.py", line 46, in add
    array.shape == expected_shape
AssertionError: Expected shape: (13, 3, [3, 3, 3]), Got shape (13, 3, 3)
```
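Reading the error, the expected tuple appears to embed the forest's `n_classes_` directly. For a multi-output fit, scikit-learn stores that attribute as a list with one class count per output, while each tree's `tree_.value` array has shape `(node_count, n_outputs, max_n_classes)`. The sketch below checks that layout using only scikit-learn and NumPy (Treelite is not involved). `expected_value_shape` is a hypothetical helper, not part of Treelite's importer, and the forest uses 5 trees instead of the default only to keep it quick.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle

# Same multi-output setup as the reproducer: three target columns, 3 classes each.
X, y1 = make_classification(
    n_samples=10, n_features=100, n_informative=30, n_classes=3, random_state=1
)
Y = np.vstack((y1, shuffle(y1, random_state=1), shuffle(y1, random_state=2))).T

# 5 trees instead of the default 100, just to keep the sketch fast.
forest = RandomForestClassifier(n_estimators=5, random_state=1).fit(X, Y)

# Multi-output: a plain list with one class count per output, e.g. [3, 3, 3].
print(forest.n_classes_)

# Per-node values are stored as (node_count, n_outputs, max_n_classes).
tree = forest.estimators_[0].tree_
print(tree.value.shape)


def expected_value_shape(node_count, n_outputs, n_classes):
    """Hypothetical helper: the shape tree_.value actually has.

    `n_classes` may be an int (single output) or a sequence with one entry
    per output (multi-output); the last axis uses the largest class count.
    """
    per_output = np.atleast_1d(n_classes)
    return (node_count, n_outputs, int(per_output.max()))


shape = expected_value_shape(tree.node_count, forest.n_outputs_, forest.n_classes_)
assert tree.value.shape == shape
```

Taking the maximum over the per-output class counts mirrors how scikit-learn sizes the last axis of `tree_.value`; whether Treelite's importer should normalize `n_classes_` the same way is an assumption of this sketch.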