cair / tmu

Implements the Tsetlin Machine, Coalesced Tsetlin Machine, Convolutional Tsetlin Machine, Regression Tsetlin Machine, and Weighted Tsetlin Machine, with support for continuous features, drop clause, Type III Feedback, focused negative sampling, multi-task classifier, autoencoder, literal budget, and one-vs-one multi-class classifier. TMU is written in Python with wrappers for C and CUDA-based clause evaluation and updating.
https://pypi.org/project/tmu/
MIT License

Saving and loading trained classifiers #43

Closed (anayurg closed this issue 1 year ago)

anayurg commented 1 year ago

Hey there,

I'm playing around with the Tsetlin classifier, and I'm having some trouble figuring out how to save a trained classifier so that it can be loaded later. Do you have a go-to method you can recommend? I couldn't get it to work with pickle or joblib, so I tried pickling the clause banks and weight banks separately. However, that seems to work only if I pickle them before the classifier has been used for any prediction, which is not an option in my case.

I'm attaching a minimal example just in case, but I imagine there must be another way to do it that I haven't figured out.

saving.py

import pickle
from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data import TMUDatasetSource

data = TMUDatasetSource().get_dataset(
    "XOR_biased",
    cache=True,
    cache_max_age=1,
    features=["X", "Y"],
    labels=["xor"],
    shuffle=True,
    train_ratio=1000,
    test_ratio=1000,
    return_type=dict
)

tm = TMClassifier(20, 10, 2)
tm.fit(data["x_train"], data["y_train"])

# This method works only if these two lines are commented out:
# y_pred = tm.predict(data["x_test"])
# print((y_pred == data["y_test"]).mean())

for i in range(len(tm.weight_banks)):
    with open('wb%s.pkl' % i, 'wb') as f:
        pickle.dump(tm.weight_banks[i], f)
for i in range(len(tm.clause_banks)):
    with open('cb%s.pkl' % i, 'wb') as f:
        pickle.dump(tm.clause_banks[i], f)

loading.py

from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data import TMUDatasetSource
import os
import pickle

data = TMUDatasetSource().get_dataset(
    "XOR_biased",
    cache=True,
    cache_max_age=1,
    features=["X", "Y"],
    labels=["xor"],
    shuffle=True,
    train_ratio=1000,
    test_ratio=1000,
    return_type=dict
)

tm_new = TMClassifier(20, 10, 2)
tm_new.clause_banks = []
tm_new.number_of_classes = 2
tm_new.weight_banks = []

files = os.listdir('/home/folder/')
cb_count = sum(1 for file in files if file.startswith('cb'))
wb_count = sum(1 for file in files if file.startswith('wb'))

for i in range(wb_count):
    with open('wb%s.pkl' % i, 'rb') as f:
        wb = pickle.load(f)
        tm_new.weight_banks.append(wb)
for i in range(cb_count):
    with open('cb%s.pkl' % i, 'rb') as f:
        cb = pickle.load(f)
        tm_new.clause_banks.append(cb)

y_pred_new = tm_new.predict(data["x_test"])
print((y_pred_new == data["y_test"]).mean())
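Separately from the error below, counting files by prefix with os.listdir is fragile: any other file whose name starts with cb or wb is counted too, and a leftover file from an earlier run that saved more banks gets loaded as well. A minimal sketch, assuming you still want to store the banks rather than the whole classifier, bundles both lists into a single pickle. save_banks and load_banks are hypothetical helpers, not part of TMU, and on their own they do not avoid the lcm_p error reported in this issue.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_banks(path: str | Path, clause_banks: list[Any], weight_banks: list[Any]) -> None:
    """Pickle both bank lists into one file (hypothetical helper, not a TMU API)."""
    state = {"clause_banks": clause_banks, "weight_banks": weight_banks}
    with open(path, "wb") as f:
        pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_banks(path: str | Path) -> tuple[list[Any], list[Any]]:
    """Return (clause_banks, weight_banks) from a file written by save_banks."""
    with open(path, "rb") as f:
        state = pickle.load(f)
    return state["clause_banks"], state["weight_banks"]


if __name__ == "__main__":
    # Plain lists stand in for TMU's bank objects in this demo.
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "banks.pkl"
        save_banks(path, [[1, 0, 1], [0, 1, 1]], [[3, -2], [1, 4]])
        cbs, wbs = load_banks(path)
        print(len(cbs), len(wbs))  # -> 2 2
```

With this layout, the two append loops in loading.py would reduce to a single call such as tm_new.clause_banks, tm_new.weight_banks = load_banks('banks.pkl').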

The error that I get when I first predict something and then save the clause and weight banks is the following:

Traceback (most recent call last):
  File "loading.py", line 36, in <module>
    y_pred_new = tm_new.predict(data["x_test"])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/envs/tmtest/lib/python3.11/site-packages/tmu/models/classification/vanilla_classifier.py", line 311, in predict
    self.clause_banks[i].calculate_clause_outputs_predict(self.encoded_X_test,
  File "/home/envs/tmtest/lib/python3.11/site-packages/tmu/clause_bank/clause_bank.py", line 168, in calculate_clause_outputs_predict
    self.lcm_p,
    ^^^^^^^^^^
AttributeError: 'ClauseBank' object has no attribute 'lcm_p'. Did you mean: 'lcc_p'?
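The traceback shows the reloaded ClauseBank lacking an lcm_p attribute that calculate_clause_outputs_predict reads. The thread does not say why the attribute was missing, only that it was later fixed in TMU. Purely as an illustration of how this class of error commonly arises in Python (the classes Bank and FixedBank and the buffer attribute are hypothetical, not TMU code): a custom __getstate__ leaves derived state out of the pickle, and without a matching __setstate__ to rebuild it, the unpickled object fails on first use.

```python
import pickle


class Bank:
    """Hypothetical stand-in for a bank object; not TMU's ClauseBank."""

    def __init__(self, n: int) -> None:
        self.n = n
        self._build_buffer()

    def _build_buffer(self) -> None:
        # Derived state that predict() depends on, rebuilt from self.n.
        self.buffer = [0] * self.n

    def predict(self) -> int:
        return len(self.buffer)

    def __getstate__(self) -> dict:
        # Leave the derived buffer out of the pickle to keep it small...
        state = self.__dict__.copy()
        state.pop("buffer", None)
        return state


class FixedBank(Bank):
    def __setstate__(self, state: dict) -> None:
        # ...and rebuild it on load so the unpickled object is usable again.
        self.__dict__.update(state)
        self._build_buffer()


if __name__ == "__main__":
    broken = pickle.loads(pickle.dumps(Bank(3)))
    try:
        broken.predict()
    except AttributeError as exc:
        print("Bank:", exc)  # raises because buffer was never restored
    fixed = pickle.loads(pickle.dumps(FixedBank(3)))
    print("FixedBank:", fixed.predict())  # -> 3
```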

Thanks for the help!

satunheim commented 1 year ago

Please also see issue #46, which is related to this.

perara commented 1 year ago

Dear @anayurg,

May I suggest just pickling and loading the whole TM object directly?

saving:

import pickle
from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data.tmu_datasource import TMUDatasetSource

data = TMUDatasetSource().get_dataset(
    "XOR_biased",
    cache=True,
    cache_max_age=1,
    features=["X", "Y"],
    labels=["xor"],
    shuffle=True,
    train_ratio=1000,
    test_ratio=1000,
    return_type=dict
)

tm = TMClassifier(20, 10, 2)
tm.fit(data["x_train"], data["y_train"])

with open('tm.pkl', 'wb') as f:
    pickle.dump(tm, f)
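Both halves of this suggestion, the pickle.dump here and the pickle.load in the loading script, can be wrapped in a pair of small helpers plus a round-trip check. This is a hypothetical sketch: save_model, load_model, predictions_survive_roundtrip and the toy MajorityModel are not TMU APIs, and with TMU installed you would pass the fitted tm instead of the toy model. The check deliberately calls predict before saving, which is the order that failed in the original report.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any, Sequence


def save_model(model: Any, path: str | Path) -> None:
    """Pickle a whole fitted model (hypothetical helper, not a TMU API)."""
    with open(path, "wb") as f:
        pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(path: str | Path) -> Any:
    """Load a model written by save_model. Only unpickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


def predictions_survive_roundtrip(model: Any, X: Sequence, path: str | Path) -> bool:
    """Predict, save, reload, and check the reloaded model predicts identically."""
    before = list(model.predict(X))
    save_model(model, path)
    after = list(load_model(path).predict(X))
    return before == after


class MajorityModel:
    """Toy stand-in with TMClassifier-like fit/predict methods (hypothetical)."""

    def fit(self, X: Sequence, y: Sequence[int]) -> "MajorityModel":
        self.label = max(set(y), key=list(y).count)
        return self

    def predict(self, X: Sequence) -> list[int]:
        return [self.label] * len(X)


if __name__ == "__main__":
    X, y = [[0, 0], [0, 1], [1, 0], [1, 1]], [0, 1, 1, 0]
    model = MajorityModel().fit(X, y)
    with tempfile.TemporaryDirectory() as tmp:
        print(predictions_survive_roundtrip(model, X, Path(tmp) / "model.pkl"))  # -> True
```

Two caveats apply to any pickled model: unpickling can execute arbitrary code, so only load files you trust; and a pickle records the module and class name of each object, so a file saved with one TMU version may fail to load after those modules are renamed or moved.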

loading:

from tmu.models.classification.vanilla_classifier import TMClassifier
from tmu.data.tmu_datasource import TMUDatasetSource
import pickle

data = TMUDatasetSource().get_dataset(
    "XOR_biased",
    cache=True,
    cache_max_age=1,
    features=["X", "Y"],
    labels=["xor"],
    shuffle=True,
    train_ratio=1000,
    test_ratio=1000,
    return_type=dict
)

with open('tm.pkl', 'rb') as f:
    tm_new = pickle.load(f)

y_pred_new = tm_new.predict(data["x_test"])
print((y_pred_new == data["y_test"]).mean())

perara commented 1 year ago

I've also fixed the cause of the error AttributeError: 'ClauseBank' object has no attribute 'lcm_p'.

Thanks for finding this :)

Reopen if this does not satisfy your requirements.