TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Can't load model #43

Closed: Xiaohui9607 closed this issue 3 years ago

Xiaohui9607 commented 3 years ago

As the title says.

xuyxu commented 3 years ago

Hi @Xiaohui9607, the latest version, which includes PR #41, adds this. You can install the package from source:

git clone https://github.com/xuyxu/Ensemble-Pytorch.git
cd Ensemble-Pytorch
pip install .
Xiaohui9607 commented 3 years ago

No, it doesn't. You still need to fit a model before you can load the checkpoint. This fails:

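from torchensemble import VotingClassifier
from torchensemble.utils import io

model = VotingClassifier(estimator=model_wrapper,  # your deep learning model
                         n_estimators=10)
# io.load looks for this file name inside the given directory:
# {Ensemble_Method_Name}_{Base_Estimator_Name}_{n_estimators}
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
                                      model.base_estimator_.__name__,
                                      model.n_estimators)
# Fails: model.estimators_ is still empty because fit() was never called
io.load(model, "/home/golf/Downloads")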

This succeeds:

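from torchensemble import VotingClassifier
from torchensemble.utils import io

model = VotingClassifier(estimator=model_wrapper,  # your deep learning model
                         n_estimators=10)
# io.load looks for this file name inside the given directory:
# {Ensemble_Method_Name}_{Base_Estimator_Name}_{n_estimators}
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
                                      model.base_estimator_.__name__,
                                      model.n_estimators)
# learning_rate, weight_decay and train_loader come from your own script
model.set_optimizer("Adam",
                    lr=learning_rate,
                    weight_decay=weight_decay)
# A one-epoch fit registers the base estimators in model.estimators_;
# io.load then overwrites their weights with the saved checkpoint
model.fit(train_loader,
          epochs=1)
io.load(model, "/home/golf/Downloads")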
xuyxu commented 3 years ago

Thanks, I will fix this ;-)

In the meantime, if you need to reuse a trained model right away, could you check whether something like pickle.dump works for you?
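A minimal sketch of the pickle.dump workaround, assuming a hypothetical LazyEnsemble class that stands in for a fitted torchensemble model (it is not torchensemble's code). Pickling serializes the whole object, including the base estimators that fit() created, so nothing has to be re-registered before the weights are usable again.

```python
import os
import pickle
import tempfile

import torch
import torch.nn as nn


class LazyEnsemble(nn.Module):
    """Hypothetical stand-in for an ensemble whose base estimators
    are only created when fit() runs."""

    def __init__(self, n_estimators):
        super().__init__()
        self.n_estimators = n_estimators
        self.estimators_ = nn.ModuleList()  # empty until fit()

    def fit(self):
        # Training is omitted; only the estimator registration matters here.
        for _ in range(self.n_estimators):
            self.estimators_.append(nn.Linear(4, 2))


model = LazyEnsemble(n_estimators=3)
model.fit()

path = os.path.join(tempfile.mkdtemp(), "ensemble.pkl")
with open(path, "wb") as f:
    pickle.dump(model, f)  # the whole object, estimators included

with open(path, "rb") as f:
    restored = pickle.load(f)

print(len(restored.estimators_))  # 3
```

Unpickling needs the same class definitions to be importable at load time, and pickle files from untrusted sources should never be loaded.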

xuyxu commented 3 years ago

@all-contributors please add @Xiaohui9607 for bug

allcontributors[bot] commented 3 years ago

@xuyxu I've put up a pull request to add @Xiaohui9607! 🎉

xuyxu commented 3 years ago

The problem is that torchensemble does not register the base estimators in the ensemble at instantiation. Instead, each one is inserted into the ensemble when the fit method is called. As a result, when we reload parameters into a newly declared model, none of the keys in the saved state_dict have a matching parameter in the model, so loading fails.
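The failure can be reproduced in plain PyTorch. The sketch below uses a hypothetical LazyEnsemble class (not torchensemble's implementation) that, like the behaviour described above, only appends its base estimators to an nn.ModuleList inside fit(). Loading a saved state_dict into a fresh instance raises a RuntimeError; pre-allocating the same number of estimators first, one possible fix, makes the keys line up.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical base estimator, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


class LazyEnsemble(nn.Module):
    """Hypothetical ensemble that registers its base estimators lazily."""

    def __init__(self, estimator_cls, n_estimators):
        super().__init__()
        self.estimator_cls = estimator_cls
        self.n_estimators = n_estimators
        self.estimators_ = nn.ModuleList()  # stays empty until fit()

    def add_estimators(self, n):
        # Create and register n fresh base estimators.
        for _ in range(n):
            self.estimators_.append(self.estimator_cls())

    def fit(self):
        # Real training is omitted; fit() is where estimators get registered.
        self.add_estimators(self.n_estimators)


trained = LazyEnsemble(TinyNet, n_estimators=3)
trained.fit()
state = trained.state_dict()  # keys such as "estimators_.0.fc.weight"

fresh = LazyEnsemble(TinyNet, n_estimators=3)
try:
    fresh.load_state_dict(state)  # fresh.estimators_ is still empty
    loaded_without_fit = True
except RuntimeError as err:
    # PyTorch reports every saved key under "Unexpected key(s) in state_dict"
    loaded_without_fit = False
    print("load failed:", "Unexpected key(s)" in str(err))

# Possible fix: pre-allocate the estimators, then load the saved weights.
fresh.add_estimators(fresh.n_estimators)
result = fresh.load_state_dict(state)
print(result.missing_keys, result.unexpected_keys)  # [] []
```

The pre-allocated estimators start with random weights, but load_state_dict overwrites all of them, so the reloaded ensemble matches the trained one.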

Thanks for the report; we will fix this.