RobustBench / robustbench

RobustBench: a standardized adversarial robustness benchmark [NeurIPS 2021 Benchmarks and Datasets Track]
https://robustbench.github.io

Error when loading CIFAR10 Linf model #174

Closed MatthewCWeston closed 7 months ago

MatthewCWeston commented 7 months ago

The Google Drive checkpoint for the CIFAR10/Linf model is now a tar archive, which breaks the load_model function. As an interim solution, the following code fixes the issue:

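# Download the checkpoint manually (notebook shell commands)
!mkdir -p models/cifar10/Linf/
!gdown 1t98aEuzeTL8P7Kpd5DIrCoCL21BNZUhC -O models/cifar10/Linf/natural.pt.tar
import torch
# Load the downloaded file and re-save it under the name load_model expects
m = torch.load('models/cifar10/Linf/natural.pt.tar')
torch.save(m, "models/cifar10/Linf/Standard.pt")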

The code below raises an error when the model is not already in the folder, but works once the code above has been run:

from robustbench.utils import load_model
load_model(model_name='Standard', dataset='cifar10', threat_model='Linf')

fra31 commented 7 months ago

Hi,

https://github.com/RobustBench/robustbench/pull/175 should fix it, please let me know if it works for you.

pratik18v commented 7 months ago

Hi @fra31, I am facing the same issue when downloading Linf models only, even with the fix you posted. I have pulled the latest version of the repo. Running the following command (with any model name):

model = load_model(model_name=<model_name>, threat_model='Linf')

gives the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pvaishnavi/projects/robustbench/robustbench/utils.py", line 156, in load_model
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
  File "/home/pvaishnavi/anaconda3/envs/robustbench/lib/python3.9/site-packages/torch/serialization.py", line 713, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/pvaishnavi/anaconda3/envs/robustbench/lib/python3.9/site-packages/torch/serialization.py", line 920, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '<'.

I noticed that the models are being downloaded from Google Drive, but the checkpoint file is only 2.4KB in size, so clearly the correct file is not being downloaded. Your assistance on the matter is greatly appreciated.

Thank you!

fra31 commented 7 months ago

Hi,

Have you installed the branch with the updates, i.e. pip install git+https://github.com/RobustBench/robustbench.git@fix-download? You'll also have to delete the old corrupted checkpoints.
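
A minimal sketch of one way to locate the old corrupted checkpoints before deleting them. It assumes, based on the "invalid load key, '<'" error and the 2.4KB file size reported above, that a corrupted checkpoint is actually an HTML page saved under a .pt name. The helper names looks_like_html and find_corrupted_checkpoints are hypothetical and not part of robustbench, and the models directory is assumed to match the path used in the workaround above.

```python
from pathlib import Path


def looks_like_html(path: Path, probe_bytes: int = 64) -> bool:
    """Return True if the file's first non-whitespace byte is '<'.

    Assumption: a failed download leaves an HTML page on disk, whereas a
    real checkpoint starts with pickle or zip header bytes instead.
    """
    with open(path, "rb") as f:
        head = f.read(probe_bytes)
    return head.lstrip().startswith(b"<")


def find_corrupted_checkpoints(model_dir: str = "models") -> list[Path]:
    """Recursively collect *.pt files under model_dir that look like HTML."""
    root = Path(model_dir)
    return sorted(
        p for p in root.rglob("*.pt") if p.is_file() and looks_like_html(p)
    )


if __name__ == "__main__":
    # Dry run: only report suspicious files; nothing is deleted here.
    for ckpt in find_corrupted_checkpoints("models"):
        print(f"suspicious checkpoint: {ckpt} ({ckpt.stat().st_size} bytes)")
```

After reviewing the reported files, each one can be removed with ckpt.unlink() so that the next load_model call downloads it again.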

MatthewCWeston commented 7 months ago

I can confirm that the fix worked as it ought to - my code now runs fine without my workaround.

@pratik18v, did you try my workaround? If you're still getting the error and my workaround doesn't solve it, then it would be a new issue. If manually loading and re-saving the model as above does fix it, then the cause is likely a cached checkpoint or the update not having been installed.

pratik18v commented 7 months ago

Hi @fra31, I just cloned the latest version of the repo (git clone https://github.com/RobustBench/robustbench.git) instead of installing through pip, as I wanted to make invasive changes to the repository. I assumed that the fix would be included in it. I'll try the specific branch that you mentioned.

@MatthewCWeston Thank you so much for sharing your workaround of manually downloading checkpoints. It works perfectly for me :)

Please consider this issue resolved at my end too, thank you so much all!