Closed Xiaohui9607 closed 3 years ago
Hi @Xiaohui9607, the latest version with PR #41 adds this. You can install the package from source:
git clone https://github.com/xuyxu/Ensemble-Pytorch.git
cd Ensemble-Pytorch
pip install .
No, it doesn't. You need to "fit" a model before you can load the checkpoint.
Fail:
from torchensemble import VotingClassifier
from torchensemble.utils import io

model = VotingClassifier(estimator=model_wrapper,  # your deep learning model
                         n_estimators=10)
# {Ensemble_Method_Name}_{Base_Estimator_Name}_{n_estimators}
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
                                      model.base_estimator_.__name__,
                                      model.n_estimators)
io.load(model, "/home/golf/Downloads")  # fails: the model has not been fitted
Success:
import os

from torchensemble import VotingClassifier
from torchensemble.utils import io

model = VotingClassifier(estimator=model_wrapper,  # your deep learning model
                         n_estimators=10)
# {Ensemble_Method_Name}_{Base_Estimator_Name}_{n_estimators}
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
                                      model.base_estimator_.__name__,
                                      model.n_estimators)
model.set_optimizer("Adam",
                    lr=learning_rate,
                    weight_decay=weight_decay)
# Fitting for one epoch registers the base estimators, so the load below works.
model.fit(train_loader,
          epochs=1)
io.load(model, "/home/golf/Downloads")
Thanks, I will fix this ;-)
If you want to use the model on the fly, could you check whether methods like pickle.dump help?
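For reference, a minimal sketch of the pickle round trip suggested above. The `SimpleNamespace` object is only a hypothetical stand-in for a fitted ensemble, and the file name `model.pkl` is an assumption for illustration; whether a given fitted model pickles cleanly would still need checking.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace

# Hypothetical stand-in for a fitted ensemble; a real model object would go here.
model = SimpleNamespace(n_estimators=10, weights=[0.5, 1.5])

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")

    # Serialize the whole object, not just its parameters.
    with open(path, "wb") as f:
        pickle.dump(model, f)

    # Restore it without re-declaring or re-fitting anything.
    with open(path, "rb") as f:
        restored = pickle.load(f)
```

Note that pickle stores classes by reference, so the class definitions must be importable when the file is loaded, and only trusted files should be unpickled.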
@all-contributors please add @Xiaohui9607 for bug
@xuyxu I've put up a pull request to add @Xiaohui9607! :tada:
The problem is that torchensemble does not register all base estimators into the ensemble during instantiation. Instead, each of them is inserted into the ensemble when the fit method is called. As a result, if we want to reload params into a newly declared model, all keys in the state_dict are missing.
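The behaviour described above can be reproduced with plain PyTorch. This is a rough sketch under assumptions, not torchensemble's actual code: `LazyEnsemble` is a hypothetical toy class whose `fit` only registers estimators and skips training. In this toy, the fresh model has an empty state_dict, so a strict load of the checkpoint raises because none of the checkpoint keys match a registered parameter.

```python
import torch
from torch import nn


class LazyEnsemble(nn.Module):
    """Toy stand-in (hypothetical): estimators are registered only in fit()."""

    def __init__(self, n_estimators):
        super().__init__()
        self.n_estimators = n_estimators
        self.estimators_ = nn.ModuleList()  # empty until fit() is called

    def fit(self):
        # Training is omitted; only the registration step matters here.
        for _ in range(self.n_estimators):
            self.estimators_.append(nn.Linear(4, 2))


# A "trained" model whose parameters we save as a checkpoint.
trained = LazyEnsemble(n_estimators=3)
trained.fit()
checkpoint = trained.state_dict()

# A newly declared model has no registered parameters yet.
fresh = LazyEnsemble(n_estimators=3)
keys_before_fit = list(fresh.state_dict().keys())

try:
    fresh.load_state_dict(checkpoint)
    strict_load_failed = False
except RuntimeError:
    # None of the checkpoint keys match a parameter in the fresh model.
    strict_load_failed = True

# Once the estimators are registered, the same checkpoint loads cleanly.
fresh.fit()
fresh.load_state_dict(checkpoint)
```

This mirrors the workaround reported earlier in the thread: calling fit once before loading gives the checkpoint parameters somewhere to land.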
Thanks for your report, we will fix this.
RT