MadryLab / robustness

A library for experimenting with, training and evaluating neural networks, with a focus on adversarial robustness.
MIT License

Loading pretrained model in Python 3.8 fails #81

Closed: rsokl closed this issue 3 years ago

rsokl commented 3 years ago

PyTorch version info:

pytorch                   1.6.0           py3.8_cuda10.2.89_cudnn7.6.5_0    pytorch
torchvision               0.7.0                py38_cu102    pytorch

The following code works in an environment with Python 3.7, but fails in an otherwise equivalent environment using Python 3.8:

from robustness import model_utils, datasets
ds = datasets.CIFAR('/tmp/')
model_path = "cifar_l2_1_0.pt"
model, checkpoint = model_utils.make_and_restore_model(arch='resnet50', dataset=ds, resume_path=model_path)
=> loading checkpoint 'cifar_l2_1_0.pt'

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-7793053c474f> in <module>
      2 
      3 model_path = "cifar_l2_1_0.pt"
----> 4 model, checkpoint = model_utils.make_and_restore_model(arch='resnet50', dataset=ds, resume_path=model_path)

~/anaconda3/envs/raiden/lib/python3.8/site-packages/robustness/model_utils.py in make_and_restore_model(arch, dataset, resume_path, parallel, pytorch_pretrained, add_custom_forward, *_)
     91     if resume_path and os.path.isfile(resume_path):
     92         print("=> loading checkpoint '{}'".format(resume_path))
---> 93         checkpoint = ch.load(resume_path, pickle_module=dill)
     94 
     95         # Makes us able to load models saved with legacy versions

~/anaconda3/envs/raiden/lib/python3.8/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    583                     return torch.jit.load(opened_file)
    584                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
--> 585         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    586 
    587 

~/anaconda3/envs/raiden/lib/python3.8/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    763     unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    764     unpickler.persistent_load = persistent_load
--> 765     result = unpickler.load()
    766 
    767     deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

~/anaconda3/envs/raiden/lib/python3.8/site-packages/dill/_dill.py in load(self)
    471 
    472     def load(self): #NOTE: if settings change, need to update attributes
--> 473         obj = StockUnpickler.load(self)
    474         if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
    475             if not self._ignore:

TypeError: an integer is required (got type bytes)
andrewilyas commented 3 years ago

Hi! Can you please indicate what version of dill is in each environment?

rsokl commented 3 years ago

Sure! Both environments are using

dill                      0.3.2                    pypi_0    pypi
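For reference, a quick way to confirm the version each interpreter actually imports (rather than what the package listing reports) is:

import dill
print(dill.__version__)  # prints 0.3.2 in both the 3.7 and 3.8 environments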
andrewilyas commented 3 years ago

Sorry, I forgot to follow up on this! Can you also post your torch version in the 3.7 environment? Is it the same?

rsokl commented 3 years ago

It is the same: 1.6.0.

Both environments have the same library versions; only the Python version differs.

andrewilyas commented 3 years ago

OK, it looks like this is a known issue with dill (https://github.com/uqfoundation/dill/issues/357), but it was patched in the latest version. If you have both 3.7 and 3.8 on your machine, you can fix this yourself as follows:

  1. Update dill to 0.3.2 in both environments.
  2. In Python 3.7, execute (a self-contained sketch of this re-save follows after this list):
    import torch, dill
    # load the checkpoint with the patched dill, then re-save it in place so
    # that it can be unpickled under Python 3.8 as well
    x = torch.load("cifar_l2_1_0.pt", pickle_module=dill)
    torch.save(x, "cifar_l2_1_0.pt", pickle_module=dill)
  3. In Python 3.8, you should now be able to load the checkpoint using the code you posted above.
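For reference, here is step 2 as a self-contained sketch to run under Python 3.7 with the updated dill installed. The output filename cifar_l2_1_0_resaved.pt is just an illustrative choice; overwriting the original file, as in step 2 above, works just as well.

import torch
import dill

in_path = "cifar_l2_1_0.pt"           # checkpoint that fails to load under Python 3.8
out_path = "cifar_l2_1_0_resaved.pt"  # illustrative name; you can also overwrite in_path

# Load with dill (the pickle module the robustness library uses for checkpoints),
# then re-save so the file is rewritten by the patched dill, which the
# Python 3.8 unpickler can read.
checkpoint = torch.load(in_path, pickle_module=dill)
torch.save(checkpoint, out_path, pickle_module=dill)
print("re-saved checkpoint to", out_path)

After that, pointing resume_path at the re-saved file in make_and_restore_model from the Python 3.8 environment should work exactly as in the snippet at the top of this issue.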

Meanwhile, we will do this re-save for all of the checkpoint links we have above so that the public versions are also cross-compatible.

rsokl commented 3 years ago

Great, I will give that a shot. Thank you for your help!

Closing this assuming that the above route will work without a hitch :)

vietvo89 commented 2 years ago

Hello @andrewilyas and @rsokl

I got the same issue when loading the l2 0.25 and l2 1.0 models, but loading the l2 0.5 and l_inf_8 models works fine. I followed your recommendation and downgraded dill from 0.3.4 to 0.3.2, but it did not work. I am using Python 3.8. Could you please help?

Thanks