MadryLab / robustness

A library for experimenting with, training and evaluating neural networks, with a focus on adversarial robustness.
MIT License

Can't load several pretrained models. #71

Closed: qimingyudaowenti closed this issue 4 years ago

qimingyudaowenti commented 4 years ago

The error below occurs when I load the pretrained ResNet50 models for CIFAR10 with L2-norm ε = 0.25 or ε = 1.0:

Traceback (most recent call last):
  File "/home/geyao/robust_analyse/test_AR_lib.py", line 32, in <module>
    final_model = main(args, store=store)
  File "/home/geyao/.local/lib/python3.8/site-packages/robustness/main.py", line 49, in main
    model, checkpoint = make_and_restore_model(arch=args.arch,
  File "/home/geyao/.local/lib/python3.8/site-packages/robustness/model_utils.py", line 93, in make_and_restore_model
    checkpoint = ch.load(resume_path, pickle_module=dill)
  File "/home/geyao/.conda/envs/gy_env/lib/python3.8/site-packages/torch/serialization.py", line 585, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/geyao/.conda/envs/gy_env/lib/python3.8/site-packages/torch/serialization.py", line 765, in _legacy_load
    result = unpickler.load()
  File "/home/geyao/.local/lib/python3.8/site-packages/dill/_dill.py", line 473, in load
    obj = StockUnpickler.load(self)
TypeError: an integer is required (got type bytes)

Other models trained on CIFAR10 load fine. Do you have any idea what is going on?

andrewilyas commented 4 years ago

What versions of PyTorch and dill are you using? (Try running pip freeze | grep dill and pip freeze | grep torch and paste the output.)
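The same information can also be collected from inside the interpreter that runs the failing script, which additionally records the Python version. A minimal sketch: report_versions is a hypothetical helper (not part of robustness), and it assumes each package exposes a __version__ attribute, which torch, torchvision and dill do.

```python
import importlib
import platform


def report_versions(packages=("torch", "torchvision", "dill")):
    """Map "python" and each package name to its version string.

    Packages that cannot be imported map to None instead of raising.
    Reads module.__version__ rather than importlib.metadata so that the
    helper also runs on Python 3.7.
    """
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = None
        else:
            report[name] = getattr(module, "__version__", None)
    return report


if __name__ == "__main__":
    # Print in the same "name==version" style as pip freeze.
    for name, ver in report_versions().items():
        print(f"{name}=={ver}")
```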

qimingyudaowenti commented 4 years ago
dill==0.3.2
torch==1.6.0
torchvision==0.7.0

Could these versions be causing the problem?

Hadisalman commented 4 years ago

@qimingyudaowenti I just tried loading these models with the same library versions you have, and it works fine for me.

Here is some code that loads and evaluates these models. I would also re-download the checkpoints.

wget -O cifar_l2_1_0.pt https://www.dropbox.com/s/s2x7thisiqxz095/cifar_l2_1_0.pt?dl=0 &&
wget -O cifar_l2_0_25.pt https://www.dropbox.com/s/s2x7thisiqxz095/cifar_l2_0_25.pt?dl=0

from robustness import model_utils, datasets
import torch as ch
from tqdm.auto import tqdm

ds = datasets.CIFAR('/tmp/')  # root directory for the CIFAR-10 data

# model_path = './cifar_l2_0_25.pt'
model_path = './cifar_l2_1_0.pt'

# Build the ResNet-50 and load the checkpoint weights into it
model, checkpoint = model_utils.make_and_restore_model(arch='resnet50', dataset=ds, resume_path=model_path)
train_loader, val_loader = ds.make_loaders(batch_size=64, workers=4)

correct = 0
model.eval()
with ch.no_grad():
    for X, y in tqdm(val_loader):
        X, y = X.cuda(), y.cuda()
        out = model(X, with_image=False)  # logits only, no adversarial image
        _, pred = out.topk(1, 1)  # top-1 predicted class, shape (batch, 1)
        correct += (pred.squeeze(1) == y).sum().item()  # plain int, not a tensor
print(f'The clean accuracy is {100. * correct / len(val_loader.dataset):.2f}%')

Let me know if this works for you.
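To rule out a bad download when re-fetching the checkpoints, one can inspect the first bytes of the file before handing it to torch.load. This is a heuristic sketch, not part of robustness or PyTorch, and sniff_checkpoint_format is a hypothetical name. It assumes a legacy-format checkpoint (the traceback above goes through torch's _legacy_load) starts with a pickle protocol opcode (byte 0x80), a zip-format checkpoint (the torch.save default since PyTorch 1.6) starts with the zip signature PK\x03\x04, and a web page saved in place of the file starts with "<".

```python
def sniff_checkpoint_format(path):
    """Heuristically classify a checkpoint file by its leading bytes."""
    with open(path, "rb") as f:
        head = f.read(64)
    if head.startswith(b"PK\x03\x04"):
        return "zip"            # zip container written by newer torch.save
    if head.startswith(b"\x80"):
        return "legacy-pickle"  # pickle PROTO opcode, older torch.save format
    if head.lstrip().startswith(b"<"):
        return "html"           # probably an HTML page saved instead of the file
    return "unknown"
```

For an intact copy of these files one would expect "legacy-pickle", given the loader path in the traceback; "html" would point to a failed download rather than a library problem.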

qimingyudaowenti commented 4 years ago

Unfortunately, the same error still occurs. I'm closing this issue because the problem doesn't seriously affect my use of the robustness library and is hard to reproduce.

Thanks for your help!!!

vietvo89 commented 2 years ago

@Hadisalman, I followed your suggestion, but it did not help. I also followed another suggestion from @andrewilyas to install dill 0.3.2, but that did not help either. I wonder why the l2_0.5, l_inf_8 and nat models load fine for me while the l2_0.25 and l2_1.0 models cannot be loaded. This is critical for me, so please reply soon. Thanks!
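To see exactly which checkpoints fail in a given environment, each file can be passed through the same call shown in the traceback, ch.load(resume_path, pickle_module=dill), and the errors collected. A sketch: try_load_all is a hypothetical helper, and the default loader additionally passes map_location="cpu" (not in the original call) so that it runs without a GPU.

```python
def try_load_all(paths, loader=None):
    """Try to load each checkpoint; return {path: None or "ErrorType: message"}."""
    if loader is None:
        # Imported lazily so the helper can be exercised with a custom loader.
        import dill
        import torch

        def loader(path):
            return torch.load(path, pickle_module=dill, map_location="cpu")

    results = {}
    for path in paths:
        try:
            loader(path)
            results[path] = None
        except Exception as exc:  # record any failure, e.g. the TypeError above
            results[path] = f"{type(exc).__name__}: {exc}"
    return results
```

For example, try_load_all(["./cifar_l2_0_25.pt", "./cifar_l2_1_0.pt"]) returns None for every file that loads and the error text for every file that does not.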

Hadisalman commented 2 years ago

@vietvo89 I just tried loading these checkpoints with

python 3.7 
dill==0.3.4
torch==1.6.0

Please use Python 3.7; these models indeed don't seem to load correctly under 3.8.

Hope this helps!
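Until the checkpoints load under 3.8, a script can fail fast with a readable message instead of the unpickling TypeError. A minimal sketch: require_python_37 is a hypothetical helper, and pinning exactly 3.7 reflects only what was reported working in this thread.

```python
import sys


def require_python_37(version_info=None):
    """Raise RuntimeError unless running on Python 3.7.

    version_info defaults to sys.version_info; any sequence whose first two
    items are (major, minor) is accepted, which keeps the check testable.
    """
    if version_info is None:
        version_info = sys.version_info
    major, minor = version_info[0], version_info[1]
    if (major, minor) != (3, 7):
        raise RuntimeError(
            "The CIFAR10 L2 eps=0.25 / eps=1.0 checkpoints were reported to load "
            f"under Python 3.7 but not 3.8; this is Python {major}.{minor}."
        )
```

Calling require_python_37() at the top of the script, before make_and_restore_model, surfaces the interpreter mismatch directly.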

vietvo89 commented 2 years ago

@Hadisalman, thanks. Yes, it works with Python 3.7. I just realized that I can load the checkpoint in 3.7 and re-save it, and then 3.8 can load the re-saved file, as @andrewilyas showed in another issue #.
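One way to do that 3.7-to-3.8 conversion is sketched below; it is not necessarily the procedure from the other issue. It assumes (not confirmed in this thread) that the TypeError comes from dill-pickled Python objects stored alongside the weights: Python 3.8 added a positional-only argument count to the code-object constructor, so code objects pickled under 3.7 cannot be rebuilt under 3.8. Run resave_for_py38 under Python 3.7. It keeps only the weights ("model", or the older "state_dict" key) and "epoch", which appear to be the entries make_and_restore_model reads. Both function names are hypothetical.

```python
def strip_checkpoint(checkpoint, keep=("model", "state_dict", "epoch")):
    """Return a new dict containing only the listed keys that are present.

    Other entries (optimizer or schedule state, if present) may hold pickled
    Python objects; the weights themselves are plain tensors.
    """
    return {key: checkpoint[key] for key in keep if key in checkpoint}


def resave_for_py38(src_path, dst_path):
    """Under Python 3.7: load a checkpoint with dill and save a slimmed copy."""
    # Imported here so strip_checkpoint can be used without torch/dill installed.
    import dill
    import torch

    checkpoint = torch.load(src_path, pickle_module=dill, map_location="cpu")
    torch.save(strip_checkpoint(checkpoint), dst_path)
```

Under Python 3.8, the file written to dst_path should then be usable as resume_path for model_utils.make_and_restore_model, as in the evaluation code earlier in the thread.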