automl / TabPFN

Official implementation of the TabPFN paper (https://arxiv.org/abs/2207.01848) and the tabpfn package.
http://priorlabs.ai
Apache License 2.0

feat: TabPFNClassifier can be pickled #9

Status: Closed (Thytu closed 2 years ago)

Thytu commented 2 years ago

Because load_model used lambdas, TabPFNClassifier couldn't be pickled. These few changes allow TabPFNClassifier to be pickled.
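The failure mode can be illustrated with a minimal sketch. This is not TabPFN's actual loading code: Model, load_model_with_lambda, load_model_picklable and is_picklable are hypothetical names used only for illustration. pickle serializes functions by reference (module plus qualified name), so a lambda created inside a function has no importable name and cannot be pickled. Replacing it with functools.partial over a module-level function keeps the same behavior while remaining picklable.

```python
import functools
import pickle


def _scale(x, factor):
    # Module-level function: pickle stores it by qualified name, so it round-trips.
    return x * factor


class Model:
    """Hypothetical stand-in for an object that stores a callable as an attribute."""

    def __init__(self, transform):
        self.transform = transform


def load_model_with_lambda(factor):
    # The lambda is local to this function, so pickle cannot look it up by name.
    return Model(lambda x: x * factor)


def load_model_picklable(factor):
    # functools.partial over a module-level function is picklable.
    return Model(functools.partial(_scale, factor=factor))


def is_picklable(obj):
    # Return True if obj can be serialized with the highest pickle protocol.
    try:
        pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    except (AttributeError, pickle.PicklingError):
        return False
    return True


if __name__ == "__main__":
    print(is_picklable(load_model_with_lambda(2.0)))  # False
    print(is_picklable(load_model_picklable(2.0)))    # True
```

Catching both AttributeError and pickle.PicklingError is deliberate, since the exception type raised for local objects may differ across Python versions.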

Test example:

```python
import os
import pickle

from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier

classifier = TabPFNClassifier(device='cpu')

# Use a context manager so the file is flushed and closed after writing.
with open(os.path.join("/tmp/", "model.pickle"), "wb") as f:
    pickle.dump(classifier, f, pickle.HIGHEST_PROTOCOL)
```

Before: AttributeError: Can't pickle local object 'load_model.<locals>.<lambda>'

After: the classifier is pickled without errors.
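The test example above only checks that dumping succeeds. A slightly stronger check is a full round trip (dump, then load) through an in-memory buffer. The sketch below is an illustration, not part of the PR: DummyClassifier and pickle_round_trip are hypothetical names, and DummyClassifier stands in for TabPFNClassifier(device='cpu') so the snippet runs without the tabpfn package installed.

```python
import io
import pickle


class DummyClassifier:
    """Hypothetical stand-in for TabPFNClassifier (the real class needs tabpfn)."""

    def __init__(self, device="cpu"):
        self.device = device


def pickle_round_trip(obj):
    # Serialize into an in-memory buffer instead of a file under /tmp,
    # then load it back so both dumping and loading are exercised.
    buffer = io.BytesIO()
    pickle.dump(obj, buffer, pickle.HIGHEST_PROTOCOL)
    buffer.seek(0)
    return pickle.load(buffer)


if __name__ == "__main__":
    restored = pickle_round_trip(DummyClassifier(device="cpu"))
    print(type(restored).__name__, restored.device)  # DummyClassifier cpu
```

With tabpfn installed, passing a real TabPFNClassifier(device='cpu') to pickle_round_trip would exercise the same path as the original test.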

Thytu commented 2 years ago

Note: I really like what you've done, guys. Great work!

I would be happy to help if you're interested 🙂

SamuelGabriel commented 2 years ago

Thank you very much for your contribution. I will look into it and merge it afterwards (in the coming weeks). Sorry we haven't merged it yet; we haven't gotten to it.

Thytu commented 2 years ago

Happy to help

SamuelGabriel commented 2 years ago

We looked into this a little further and found that, together with the change you propose, we could make some deeper changes to minimize our dependencies. So we implemented our own version of your PR, which gets rid of the old loading function altogether. I will close this PR, but I very much appreciate it and am happy that we could implement the functionality you were missing! If you want to use it, it is already in the current pip release :)