MatthewCWeston closed this issue 7 months ago.
Hi,
https://github.com/RobustBench/robustbench/pull/175 should fix it. Please let me know if it works for you.
Hi @fra31, I am facing the same issue, but only when downloading Linf models, even with the fix you posted. I have pulled the latest version of the repo. Running the following command (with any model name):
from robustbench.utils import load_model
model = load_model(model_name=<model_name>, threat_model='Linf')
gives the following error:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pvaishnavi/projects/robustbench/robustbench/utils.py", line 156, in load_model
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
  File "/home/pvaishnavi/anaconda3/envs/robustbench/lib/python3.9/site-packages/torch/serialization.py", line 713, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/pvaishnavi/anaconda3/envs/robustbench/lib/python3.9/site-packages/torch/serialization.py", line 920, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '<'.
I noticed that the models are being downloaded from Google Drive, but the checkpoint file is only 2.4 KB, so clearly the correct file is not being downloaded. Your assistance on the matter is greatly appreciated.
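The error "invalid load key, '<'" means the first byte pickle read was '<', which together with the 2.4 KB size suggests the downloaded file is most likely a text/HTML page rather than a checkpoint. A minimal sketch for checking a downloaded file, assuming only the standard library; the helper name inspect_checkpoint is hypothetical and not part of RobustBench:

```python
from pathlib import Path

def inspect_checkpoint(path, probe_bytes=512):
    """Hypothetical helper: report a checkpoint's size and whether it
    looks like an HTML/text page instead of a serialized PyTorch file."""
    p = Path(path)
    with p.open("rb") as f:
        head = f.read(probe_bytes)
    return {
        "size_bytes": p.stat().st_size,
        # Zip-format torch.save files start with b"PK"; legacy pickle-based
        # files start with a binary opcode (b"\x80"). A leading '<' (after
        # optional whitespace) indicates an HTML page was saved instead.
        "looks_like_html": head.lstrip().startswith(b"<"),
    }
```

A result such as {'size_bytes': 2400, 'looks_like_html': True} would match the symptom described above.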
Thank you!
Hi,
have you installed the branch with the updates, i.e. pip install git+https://github.com/RobustBench/robustbench.git@fix-download? You'll also have to delete the old corrupted checkpoints.
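Deleting the old corrupted checkpoints can be sketched as follows. This assumes checkpoints live under a ./models directory and use the .pt extension (both assumptions about the local setup); remove_corrupted_checkpoints is a hypothetical helper, and it only flags files whose first non-whitespace byte is '<'. It defaults to a dry run so nothing is deleted until you confirm the list:

```python
from pathlib import Path

def remove_corrupted_checkpoints(model_dir="./models", dry_run=True):
    """Hypothetical helper: find (and optionally delete) .pt files under
    model_dir that start with '<', i.e. that look like saved HTML pages."""
    flagged = []
    for path in Path(model_dir).rglob("*.pt"):
        with path.open("rb") as f:
            head = f.read(512)
        if head.lstrip().startswith(b"<"):
            flagged.append(path)
            if not dry_run:
                path.unlink()  # delete so the next load_model call re-downloads
    return sorted(flagged)
```

Run it once with dry_run=True to review the list, then with dry_run=False to delete the files.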
I can confirm that the fix worked as it ought to - my code now runs fine without my workaround.
@pratik18v, did you try my workaround? If you're still getting the error and my workaround doesn't solve it, then it'd be a new issue. If manually loading and re-saving the model as above does fix it, then the issue is likely a cached checkpoint or the update not having been installed.
Hi,
@fra31 I just cloned the latest version of the repo (git clone https://github.com/RobustBench/robustbench.git) instead of installing through pip, as I wanted to make invasive changes to the repository. I assumed the fix would be included in it. I'll try the specific branch you mentioned.
@MatthewCWeston Thank you so much for sharing your workaround of manually downloading checkpoints. It works perfectly for me :)
Please consider this issue resolved on my end too. Thank you all so much!
The Google Drive model for CIFAR10/Linf is now a tar archive, which breaks the load_model function. As an interim solution, the following code will fix the issue:
The code below raises an error if the model is not already in the folder, but works once the code above has been run:
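Separately from the workaround referred to above, here is a hypothetical sketch of detecting whether a downloaded checkpoint is a tar archive and unpacking it, using only the standard library tarfile module. The helper name unpack_if_tar and the destination directory are assumptions; the sketch makes no claim about what the RobustBench archive actually contains, and the extracted file may still need to be loaded and re-saved before load_model accepts it:

```python
import tarfile
from pathlib import Path

def unpack_if_tar(path, dest_dir):
    """Hypothetical helper: if `path` is a tar archive, extract its regular
    files into `dest_dir` and return their paths; otherwise return []."""
    if not tarfile.is_tarfile(path):
        return []
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    root = dest.resolve()
    with tarfile.open(path) as tar:
        # Keep only regular files whose target path stays inside dest_dir.
        files = [
            m for m in tar.getmembers()
            if m.isfile() and root in (dest / m.name).resolve().parents
        ]
        if hasattr(tarfile, "data_filter"):
            # Python versions with extraction filters: also reject unsafe metadata.
            tar.extractall(dest, members=files, filter="data")
        else:
            tar.extractall(dest, members=files)
    return sorted(str(dest / m.name) for m in files)
```

A non-empty return value confirms the file was an archive; torch.load would then be pointed at one of the extracted files instead.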