TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

How to load a saved ensemble #66

Closed: mpierrau closed this issue 3 years ago

mpierrau commented 3 years ago

Hi!

Thank you for a nice ensemble wrapper!

I can't seem to figure out how to load a saved model, and the docs have no instructions for this. I tried calling load_state_dict with the loaded state dictionary, but I get a RuntimeError (missing keys). Running:

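import torch
from torchvision import models
from torchensemble import VotingClassifier

kwargs = {'pretrained': False, 'num_classes': 9}

model = VotingClassifier(models.mobilenet_v2(**kwargs), n_estimators=5, cuda=True)
model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)
model.fit(train_dl, epochs=10)  # train_dl is my training DataLoader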

and then

d = torch.load('VotingClassifier_MobileNetV2_5_ckpt.pth')
model2 = VotingClassifier(models.mobilenet_v2(**kwargs), n_estimators=5, cuda=True)
model2.load_state_dict(d)  # fails with the error below

gives the error

RuntimeError: Error(s) in loading state_dict for MobileNetV2:
    Missing key(s) in state_dict: "features.0.0.weight", "features.0.1.weight", "features.0.1.bias", "features.0.1.running_mean", "features.0.1.running_var", "features.1.conv.0.0.weight", "features.1.conv.0.1.weight", "features.1.conv.0.1.bias", "features.1.conv.0.1.running_mean", "features.1.conv.0.1.running_var", "features.1.conv.1.weight", "features.1.conv.2.weight", "features.1.conv.2.bias", "features.1.conv.2.running_mean", "features.1.conv.2.running_var", "features.2.conv.0.0.weight", "features.2.conv.0.1.weight", "features.2.conv.0.1.bias", "features.2.conv.0.1.running_mean", "features.2.conv.0.1.running_var", "features.2.conv.1.0.weight", ...

etc.

I have also tried loading the dict into the individual base learners, as well as loading d['model'], but without success.

Hoping you can help!

Thank you!
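A likely reason for the missing-key error, sketched in plain PyTorch: an ensemble keeps its members in a container that stays empty until training, so a freshly constructed ensemble has no parameters matching the per-member keys saved in the checkpoint. The `ToyEnsemble` class below is hypothetical; its `estimators_` ModuleList mirrors the attribute torchensemble ensembles use, but treat the rest as an assumption rather than the library's actual code. Separately, the mention of `d['model']` suggests the file is a wrapper dict with the parameters nested under `'model'`, so passing `d` itself to `load_state_dict` cannot match any parameter names.

```python
import torch
from torch import nn


class ToyEnsemble(nn.Module):
    """Hypothetical stand-in for an ensemble: members live in a ModuleList
    that stays empty until training builds them."""

    def __init__(self, n_estimators):
        super().__init__()
        self.n_estimators = n_estimators
        self.estimators_ = nn.ModuleList()  # no members yet

    def make_estimator(self):
        return nn.Linear(4, 3)

    def build_members(self):
        # A real fit() would also train the members; here we only create them.
        for _ in range(self.n_estimators):
            self.estimators_.append(self.make_estimator())


torch.manual_seed(0)
trained = ToyEnsemble(n_estimators=5)
trained.build_members()
state = trained.state_dict()
print(sorted(state)[:2])  # per-member keys, e.g. 'estimators_.0.bias'

# A freshly constructed ensemble has no parameters to receive these keys,
# so strict loading (the default) raises a RuntimeError.
fresh = ToyEnsemble(n_estimators=5)
try:
    fresh.load_state_dict(state)
    loaded_into_empty = True
except RuntimeError as err:
    loaded_into_empty = False
    print(str(err).splitlines()[0])

# Rebuilding the same number of members first makes the keys line up.
fresh.build_members()
fresh.load_state_dict(state)
```

Rebuilding the members before restoring their weights is, as far as we can tell, what torchensemble.utils.io.load takes care of, which is consistent with io.load resolving the issue later in this thread.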

xuyxu commented 3 years ago

Hi @mpierrau, the test_clf_class function in test_all_models.py shows an example of how to save and reload ensembles.

Feel free to let me know if you run into any problems.

EDIT: Thanks for reporting; we will add more explanation to the documentation.

mpierrau commented 3 years ago

Hi @xuyxu!

Thank you for your prompt reply and for the pointer. I managed to load a saved model using the load function from the io module in torchensemble.utils (torchensemble.utils.io.load).

Since this is a commonly used function, I think the documentation would benefit from covering it.

Thanks for the help! :)

xuyxu commented 3 years ago

Sure, we will add this. Glad to hear that your problem is solved.