mwydmuch / napkinXC

Extremely simple and fast extreme multi-class and multi-label classifiers.
MIT License

pickling models #22

Open xiaohan2012 opened 3 years ago

xiaohan2012 commented 3 years ago

Hi,

I'm trying to use napkinXC under Ray, which relies on pickle for data serialization.

It seems that napkinXC models cannot be pickled.

For instance,

import pickle
from napkinxc.models import PLT

model = PLT('/tmp/something/')
with open('/tmp/some-pickle.pkl', 'wb') as f:
    pickle.dump(model, f)

gives:

TypeError: cannot pickle 'napkinxc._napkinxc.CPPModel' object.

Is there a workaround, or is there a plan to support pickling?
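For context, the error is not specific to `PLT`. By default `pickle` saves an instance's `__dict__`, so a single attribute that wraps a native object without pickling support (here the compiled `CPPModel`) makes the whole dump fail. The standard escape hatch is to define `__getstate__`/`__setstate__` so that only picklable data is saved and the native part is rebuilt on load. Below is a minimal sketch of both the failure and that fix. It does not use napkinXC: the hypothetical `NativeHandleModel` class holds a `threading.Lock` that stands in for the C++ handle.

```python
import pickle
import threading


class NativeHandleModel:
    """Hypothetical stand-in: `_handle` plays the role of napkinXC's CPPModel."""

    def __init__(self, output):
        self.output = output
        self._handle = threading.Lock()  # like CPPModel, a lock has no pickle support


class PicklableNativeHandleModel(NativeHandleModel):
    def __getstate__(self):
        # Keep only plain-Python parameters and drop the unpicklable handle.
        return {"output": self.output}

    def __setstate__(self, state):
        # Rebuild the object, including a fresh handle, from the saved parameters.
        self.__init__(**state)


def roundtrip(obj):
    """Pickle and unpickle `obj`; return the copy, or the TypeError raised."""
    try:
        return pickle.loads(pickle.dumps(obj))
    except TypeError as exc:
        return exc


if __name__ == "__main__":
    print(repr(roundtrip(NativeHandleModel("/tmp/something/"))))  # a TypeError
    print(roundtrip(PicklableNativeHandleModel("/tmp/something/")).output)  # /tmp/something/
```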

xiaohan2012 commented 3 years ago

A quick update: the workaround below seems to work.

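import pickle
from napkinxc.models import PLT
from napkinxc.datasets import load_dataset
from napkinxc.measures import precision_at_k

class PickleablePLT(PLT):
    """A picklable PLT: pickles only the parameters and reloads the weights from disk."""
    def __getstate__(self):
        # Store the constructor parameters (including the output directory)
        # instead of the unpicklable CPPModel handle.
        return self.get_params()

    def __setstate__(self, params):
        # Re-create the model from its parameters, then load the trained
        # weights from the output directory; the files must still exist there.
        self.__init__(**params)
        self.load()

trn_X, trn_Y = load_dataset('eurlex-4k', "train", verbose=1)
tst_X, tst_Y = load_dataset('eurlex-4k', "test", verbose=1)

model = PickleablePLT('/tmp/someplt')
model.fit(trn_X, trn_Y)

with open('./myplt.pkl', 'wb') as f:
    pickle.dump(model, f)

with open('./myplt.pkl', 'rb') as f:
    model_p = pickle.load(f)

# Top-5 predictions with probabilities from the unpickled copy
preds = model_p.predict_proba(tst_X, 5)

print(precision_at_k(tst_Y, preds))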
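One caveat on this workaround: `__getstate__` stores only the parameters, so the unpickled copy reloads its weights from the output directory (`/tmp/someplt`). That works within one machine. A Ray worker on another node, however, will not find those files unless the directory sits on a filesystem that all nodes share. The sketch below is a more portable variant. It assumes that all of a model's files live under its output directory, packs that directory into the pickled state as an in-memory tar archive, and unpacks it into a fresh temporary directory on load. `DirectoryBackedModel`, its `weights.txt` file, and the `pack_dir`/`unpack_dir` helpers are hypothetical stand-ins, not napkinXC APIs. Adapting the idea to `PickleablePLT` would mean combining the archive with the `get_params()` dictionary and pointing the output parameter at the new directory, which has not been tested here.

```python
import io
import os
import pickle
import shutil
import tarfile
import tempfile


class DirectoryBackedModel:
    """Hypothetical stand-in for a model whose trained state lives in `output`."""

    def __init__(self, output):
        self.output = output
        os.makedirs(output, exist_ok=True)

    def fit(self, weights):
        # Stand-in for training: write the "weights" to the output directory.
        with open(os.path.join(self.output, "weights.txt"), "w") as f:
            f.write(weights)

    def load(self):
        # Stand-in for loading: read the weights back from the output directory.
        with open(os.path.join(self.output, "weights.txt")) as f:
            self.weights = f.read()


def pack_dir(path):
    """Return every file under `path` as the bytes of an in-memory tar archive."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name in os.listdir(path):
            tar.add(os.path.join(path, name), arcname=name)  # recurses into subdirectories
    return buf.getvalue()


def unpack_dir(data, path):
    """Extract an archive produced by pack_dir() into `path`."""
    with tarfile.open(fileobj=io.BytesIO(data), mode="r") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path, filter="data")  # reject unsafe members where supported
        else:
            tar.extractall(path)


class PortableModel(DirectoryBackedModel):
    def __getstate__(self):
        # Ship the model files themselves, not just the path to them.
        return {"files": pack_dir(self.output)}

    def __setstate__(self, state):
        # Unpack into a fresh local directory (the caller is responsible for
        # deleting it later), then rebuild and load as in the workaround.
        local_dir = tempfile.mkdtemp(prefix="model-")
        unpack_dir(state["files"], local_dir)
        self.__init__(local_dir)
        self.load()


if __name__ == "__main__":
    src = PortableModel(tempfile.mkdtemp())
    src.fit("w=0.5")
    blob = pickle.dumps(src)
    shutil.rmtree(src.output)  # simulate a worker that never saw the original directory
    clone = pickle.loads(blob)
    print(clone.weights)  # w=0.5
```

The trade-off is size: every pickle now carries the full model, which can be large for extreme classifiers. If all nodes can see a shared directory, the lighter parameters-only workaround avoids that cost.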