scikit-learn-contrib / hiclass

A python library for hierarchical classification compatible with scikit-learn
BSD 3-Clause "New" or "Revised" License

Problem dumping model with pickle #126

Status: Open. tomtaylor opened this issue 2 months ago.

tomtaylor commented 2 months ago

Describe the bug

I've just trained a classifier, but when I try to dump the model with pickle, I get the following error:

Traceback (most recent call last):
  File "/obfuscated/train.py", line 43, in <module>
    pickle.dump(classifier, open(filename, "wb"), protocol=5)
_pickle.PicklingError: Can't pickle <function check_array at 0x14df40540>: it's not the same object as sklearn.utils.validation.check_array

To Reproduce

I'm using scikit-learn==1.5.0 and hiclass==4.10.0 on Python 3.11.1, with the following code sample:

print("Preparing test/train split")
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=42
)

hg = HistGradientBoostingClassifier(
    random_state=42, verbose=1, categorical_features=[0]
)
classifier = LocalClassifierPerNode(local_classifier=hg)

print("Fitting classifier")
classifier.fit(x_train, y_train)

print("Saving model")
filename = "model.pkl"
pickle.dump(classifier, open(filename, "wb"), protocol=5)

Expected behavior

I'd expect the model to be dumped with pickle without errors.

(PS: thanks a lot for this library!)

tomtaylor commented 2 months ago

Changing categorical_features=[0] to categorical_features="from_dtype" seems to have fixed this!
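One caveat worth checking with this workaround: in scikit-learn, categorical_features="from_dtype" only treats DataFrame columns with a `category` dtype as categorical, so if the features reach the estimator as a plain array, the setting may simply turn off categorical handling rather than fix the underlying pickling problem. Separately, the original snippet passes `open(filename, "wb")` straight to `pickle.dump`, which creates (or truncates) the file before serialization starts, so a failed dump leaves a broken `model.pkl` behind. A minimal, stdlib-only sketch that serializes in memory first; `save_model` and `load_model` are hypothetical helper names, not part of hiclass or scikit-learn, and a dict stands in for the fitted classifier:

```python
import os
import pickle
import tempfile


def save_model(model, path, protocol=5):
    """Hypothetical helper: pickle `model` in memory, then write it to `path`."""
    # Serializing first means a PicklingError is raised before the file
    # is opened, so no empty or truncated model file is left behind.
    payload = pickle.dumps(model, protocol=protocol)
    # The context manager closes the handle; passing open(...) directly
    # to pickle.dump() leaves closing it to the garbage collector.
    with open(path, "wb") as fh:
        fh.write(payload)
    return len(payload)


def load_model(path):
    """Hypothetical helper: read back a model written by save_model()."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Usage with a plain dict standing in for the fitted classifier.
with tempfile.TemporaryDirectory() as tmp:
    good_path = os.path.join(tmp, "model.pkl")
    save_model({"classes": ["a", "b"]}, good_path)
    restored = load_model(good_path)

    # A lambda cannot be pickled by reference, so this save must fail.
    bad_path = os.path.join(tmp, "broken.pkl")
    try:
        save_model({"fn": lambda x: x}, bad_path)
    except (pickle.PicklingError, AttributeError):
        pass
    bad_file_created = os.path.exists(bad_path)

print(restored, bad_file_created)
```

With a real estimator, calling `save_model(classifier, "model.pkl")` would surface the same `PicklingError` as the traceback above, but without creating the file.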

mirand863 commented 2 months ago

Hi @tomtaylor,

Thank you for reporting this issue.

I have never come across this error before, but I have reason to believe it is caused by a conflict when importing libraries. Can you please send me the full code, including the imports? It will be easier to figure out the cause if I can reproduce it.
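The error message itself is consistent with an import-level conflict. pickle stores a plain function by reference (its module plus qualified name) and, while dumping, looks that name up again and refuses if it now resolves to a different object. The stdlib-only sketch here reproduces that class of error; the module name `fake_validation` and the wrapper are hypothetical stand-ins for `sklearn.utils.validation` and whatever might rebind `check_array`, so it illustrates the failure mode rather than confirming the cause of this issue:

```python
import functools
import pickle
import sys
import types

# Hypothetical stand-in module; the real traceback involves
# sklearn.utils.validation, which this sketch does not import.
fake_validation = types.ModuleType("fake_validation")
sys.modules["fake_validation"] = fake_validation


def check_array(array):
    """Toy validation helper standing in for the real one."""
    return list(array)


# Pretend the helper was defined inside fake_validation.
check_array.__module__ = "fake_validation"
fake_validation.check_array = check_array

# Functions pickle by reference ("fake_validation.check_array"),
# so a round trip returns the identical object.
assert pickle.loads(pickle.dumps(check_array)) is check_array

# Something rebinds the module attribute, e.g. a monkeypatch or a
# decorator. functools.wraps copies __module__ and __qualname__, so the
# wrapper claims the same name as the original.
original_check_array = fake_validation.check_array


@functools.wraps(original_check_array)
def wrapped_check_array(array):
    return original_check_array(array)


fake_validation.check_array = wrapped_check_array

# An object still holding the *original* function can no longer be
# pickled: its recorded name now resolves to a different object.
model_state = {"validator": original_check_array}
try:
    pickle.dumps(model_state, protocol=5)
    pickling_error = None
except pickle.PicklingError as exc:
    pickling_error = exc

print(type(pickling_error).__name__, pickling_error)
```

If the same thing is happening here, something in the process would have replaced `sklearn.utils.validation.check_array` after the fitted model captured a reference to the original; comparing the function object held by the model with the current module attribute (using `is`) is one way to test that hypothesis.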