MadryLab / robustness

A library for experimenting with, training and evaluating neural networks, with a focus on adversarial robustness.
MIT License
903 stars, 181 forks

Pre-Trained Model #87

Closed. sdan2 closed this issue 3 years ago.

sdan2 commented 3 years ago

How exactly do I load a pre-trained model? I have downloaded the cifar_linf_8.pt file, and then I do the following:

```python
import torch
from robustness.cifar_models.resnet import ResNet50

model = ResNet50()
model.load_state_dict(torch.load("cifar_nat.pt"))  # raises the key-mismatch error described below
```

However, I am getting an error saying the keys do not match. Do you have any suggestions on what to do? Thanks!
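A likely cause of the mismatch, offered as an assumption rather than something this thread confirms: checkpoints written by robustness's training code appear to store the weights of the whole adversarial wrapper (behind a `DataParallel` `module.` prefix) inside a larger dict, so a bare `ResNet50` sees unexpected key names. The hypothetical helper `strip_prefix` below keeps only the entries under an assumed `module.model.` prefix and removes that prefix; the example keys are illustrative and not read from a real file.

```python
def strip_prefix(state_dict, prefix="module.model."):
    """Return the entries whose key starts with `prefix`, with `prefix` removed.

    The default prefix is an assumption about the checkpoint layout
    (DataParallel "module." + the wrapped classifier "model.").
    Entries outside that prefix (e.g. normalizer or attacker copies) are dropped.
    """
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Illustrative keys only; real checkpoints hold tensors, not ints.
example = {
    "module.normalizer.new_mean": 0,
    "module.model.conv1.weight": 1,
    "module.attacker.model.conv1.weight": 2,
}
print(strip_prefix(example))  # {'conv1.weight': 1}

# Hypothetical use with a real file (requires torch and the robustness package):
# ckpt = torch.load("cifar_linf_8.pt", map_location="cpu")
# sd = ckpt["model"] if "model" in ckpt else ckpt["state_dict"]
# model = ResNet50()
# model.load_state_dict(strip_prefix(sd))
```

If weights are loaded into a bare `ResNet50` this way, the wrapper is presumably no longer normalizing inputs, so the dataset's mean/std normalization would have to be applied manually.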

andrewilyas commented 3 years ago

Yes, see the `robustness.model_utils.make_and_restore_model` function; that should help!
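A minimal sketch of that call for a CIFAR ResNet-50 checkpoint such as cifar_linf_8.pt. It assumes the robustness package's `datasets.CIFAR` class and the `arch`, `dataset`, and `resume_path` keyword arguments; the wrapper function name and the paths are hypothetical.

```python
def load_robust_cifar_resnet50(data_path, checkpoint_path):
    """Restore a robustness CIFAR ResNet-50 checkpoint (hypothetical helper)."""
    # Imported lazily so this sketch can be defined even without the package installed.
    from robustness import datasets, model_utils

    # The dataset object tells the library the number of classes and normalization stats.
    ds = datasets.CIFAR(data_path)
    # Returns the restored model and the loaded checkpoint; the checkpoint is unused here.
    model, _ = model_utils.make_and_restore_model(
        arch="resnet50", dataset=ds, resume_path=checkpoint_path
    )
    model.eval()  # switch batch norm to inference mode
    return model
```

Usage would look like `model = load_robust_cifar_resnet50("/path/to/cifar", "cifar_linf_8.pt")`. The returned object is assumed to be the library's wrapper around the classifier rather than the bare `ResNet50`, so check the library documentation for what its forward call returns before using the outputs.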

sdan2 commented 3 years ago

Thanks, that worked!