aksnzhy / xlearn

High-performance, easy-to-use, and scalable machine learning (ML) package, including linear models (LR), factorization machines (FM), and field-aware factorization machines (FFM), with Python and CLI interfaces.
https://xlearn-doc.readthedocs.io/en/latest/index.html
Apache License 2.0
3.08k stars, 518 forks

How to save/load scikit-learn API model #123

Open ethen8181 opened 6 years ago

ethen8181 commented 6 years ago

Hi, team. What's the best approach right now to save and load a scikit-learn-like model? Pickling doesn't seem to work. Thanks!

```python
# the example from the tutorial
import numpy as np
import xlearn as xl
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load dataset
iris_data = load_iris()
X = iris_data['data']
y = (iris_data['target'] == 2)

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, random_state=0)

# param:
#  0. binary classification
#  1. model initialization scale: 0.1
#  2. epoch number: 10 (auto early-stop)
#  3. learning rate: 0.1
#  4. regularization lambda: 1.0
#  5. use SGD optimization method
linear_model = xl.LRModel(task='binary', init=0.1,
                          epoch=10, lr=0.1,
                          reg_lambda=1.0, opt='sgd')

# Start to train
linear_model.fit(X_train, y_train,
                 eval_set=[X_val, y_val],
                 is_lock_free=False)

# attempting to save the model
from joblib import dump
dump(linear_model, 'temp.pkl')
# ValueError: ctypes objects containing pointers cannot be pickled
```

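
The error is raised by ctypes itself, not by joblib: pickling any object that keeps a ctypes pointer as an attribute fails the same way. A minimal sketch that reproduces it without xlearn; `HandleHolder` and `try_pickle` are hypothetical names, and the pointer only stands in for whatever native handle the real wrapper holds.

```python
import ctypes
import pickle


class HandleHolder:
    """Hypothetical object that, like a sklearn-style wrapper around a C
    library, stores a ctypes pointer to native memory as an attribute."""

    def __init__(self):
        self.handle = ctypes.pointer(ctypes.c_int(0))


def try_pickle(obj):
    # Return the error message if pickling fails, or None if it succeeds.
    try:
        pickle.dumps(obj)
        return None
    except ValueError as exc:
        return str(exc)


print(try_pickle(HandleHolder()))
```

Since `joblib.dump` pickles the object under the hood, it hits the same ctypes check as plain `pickle.dumps`.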
aksnzhy commented 6 years ago

@randxie Can you check out this issue?

randxie commented 6 years ago

@aksnzhy The problem is likely that the underlying _XLearnModel object cannot be pickled. We could either add `__getstate__` and `__setstate__` methods to the XLearn class, or provide save_model and load_model methods in the sklearn interface. What do you think?
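
A minimal sketch of both options, assuming a hypothetical `NativeBackedModel` class rather than xlearn's real internals: its fitted state lives behind a ctypes pointer (standing in for the native handle), `save_model`/`load_model` write and read a made-up plain-text weight file, and `__getstate__`/`__setstate__` reuse those two methods through a temporary file. None of these names or formats are xlearn's actual API.

```python
import ctypes
import os
import pickle
import tempfile


class NativeBackedModel:
    """Hypothetical sklearn-style wrapper whose fitted state lives behind a
    ctypes pointer, standing in for a handle into a C/C++ library."""

    def __init__(self):
        self._handle = None  # set by fit(); ctypes pointers are not picklable

    def fit(self, weights):
        # Hypothetical "training": copy the weights into ctypes-owned memory.
        arr = (ctypes.c_double * len(weights))(*weights)
        self._handle = ctypes.pointer(arr)
        return self

    def coef(self):
        return list(self._handle.contents) if self._handle is not None else []

    # Option 2: explicit save_model / load_model (plain-text format is made up).
    def save_model(self, path):
        with open(path, "w") as f:
            f.write("\n".join(repr(w) for w in self.coef()))

    def load_model(self, path):
        with open(path) as f:
            weights = [float(line) for line in f if line.strip()]
        return self.fit(weights)

    # Option 1: pickle support built on top of save_model / load_model.
    def __getstate__(self):
        state = self.__dict__.copy()
        state["_handle"] = None  # drop the unpicklable pointer
        state["_model_bytes"] = None
        if self._handle is not None:
            fd, path = tempfile.mkstemp()
            os.close(fd)
            try:
                self.save_model(path)
                with open(path, "rb") as f:
                    state["_model_bytes"] = f.read()
            finally:
                os.remove(path)
        return state

    def __setstate__(self, state):
        model_bytes = state.pop("_model_bytes")
        self.__dict__.update(state)
        if model_bytes is not None:
            fd, path = tempfile.mkstemp()
            with os.fdopen(fd, "wb") as f:
                f.write(model_bytes)
            try:
                self.load_model(path)  # rebuilds the native handle
            finally:
                os.remove(path)


model = NativeBackedModel().fit([0.5, -1.25, 2.0])
restored = pickle.loads(pickle.dumps(model))
print(restored.coef())
```

Building `__getstate__` on top of `save_model` keeps a single serialization path for both options; the temporary file mirrors a native library that can only write its model to disk, which is an assumption here rather than something confirmed for xlearn.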

Superhzf commented 5 years ago

I have the same problem. Is there any progress on this?