Open xiaohan2012 opened 3 years ago
Quick update: the workaround below seems to work.
import pickle
from napkinxc.models import PLT
from napkinxc.datasets import load_dataset
from napkinxc.measures import precision_at_k
class PickleablePLT(PLT):
    """PLT subclass that can be serialized with the standard ``pickle`` module.

    Only the constructor parameters, as returned by ``get_params()``, are
    written into the pickle. On unpickling, the object is rebuilt from those
    parameters and ``load()`` is called to restore the trained model.

    NOTE(review): the trained model itself is not part of the pickled state.
    ``load()`` presumably reads it back from the location passed to the
    constructor, so that location must still exist and be readable wherever
    the object is unpickled. Confirm this before relying on it across
    machines (e.g. on a multi-node Ray cluster).
    """

    def __getstate__(self):
        """Return the constructor parameters to use as the pickled state."""
        state = self.get_params()
        return state

    def __setstate__(self, state):
        """Rebuild this instance from *state*, then reload the trained model."""
        # pickle creates the instance without running __init__, so re-run it
        # with the saved parameters before asking the model to load itself.
        self.__init__(**state)
        self.load()
# Train a picklable PLT on the eurlex-4k train split, round-trip it through
# pickle, and evaluate the restored copy on the test split.
trn_X, trn_Y = load_dataset('eurlex-4k', "train", verbose=1)
tst_X, tst_Y = load_dataset('eurlex-4k', "test", verbose=1)

# '/tmp/someplt' is the location handed to the model constructor; with the
# PickleablePLT workaround it must still be readable when unpickling.
model = PickleablePLT('/tmp/someplt')
model.fit(trn_X, trn_Y)

# Use context managers so the pickle file is flushed and closed before it is
# reopened for reading. Relying on garbage collection to close an anonymous
# file object is CPython-specific behaviour.
with open('./myplt.pkl', 'wb') as pkl_file:
    pickle.dump(model, pkl_file)
with open('./myplt.pkl', 'rb') as pkl_file:
    model_p = pickle.load(pkl_file)

# Predict with the restored model and report precision@k.
# NOTE(review): the positional 5 is presumably the number of top labels
# (top_k); confirm against the napkinXC predict_proba signature.
preds = model_p.predict_proba(tst_X, 5)
print(precision_at_k(tst_Y, preds))
Hi,
I'm trying to use napkinXC under Ray, which relies on `pickle` for data serialization. It seems that napkinXC models cannot be pickled.
For instance,
gives:
Is there a workaround, or any plan to support pickling?