RulinShao / on-the-adversarial-robustness-of-visual-transformer

Code for the paper "On the Adversarial Robustness of Visual Transformers"

The use of pretrain weights in adversarial training #1

Closed: caposerenity closed this issue 3 years ago

caposerenity commented 3 years ago

Hi @RulinShao, thanks a lot for your impressive work. I'm trying to reproduce some of your experiments, and I noticed that this repo uses timm for the pretrained ViT-B/16. When I run training the way train.py does, loading the pretrained weights of vit_base_patch16_224_in21k fails with: RuntimeError: Expected hasRecord("version") to be true, but got false.

Environment: torch==1.9.0, torchvision==0.10.0, running on Colab.

I also tried older versions of torch. They all failed to load the pretrained weights, with different errors such as "Only one file(not dir) is allowed in the zipfile". The problem seems to be related to the zip-compressed weight files that timm downloads.

Could you please tell me which environment you used for this work? Thanks a lot!
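A quick way to collect the environment details asked about here is to query the installed package versions. This is a minimal sketch: `report_versions` is a hypothetical helper, not part of the repo, and it assumes Python 3.8+ for `importlib.metadata` (the Python 3.7 Colab runtime that appears in the traceback paths in this thread would need the `importlib_metadata` backport instead).

```python
# Hypothetical helper: report Python and package versions for a bug report.
# Assumes Python 3.8+ (importlib.metadata is in the stdlib from 3.8 on).
import platform
from importlib import metadata


def report_versions(packages=("torch", "torchvision", "timm")):
    """Return {name: version or None}, plus the running Python version."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Pasting this output into an issue makes version mismatches, such as a different timm release, visible at a glance.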

caposerenity commented 3 years ago

Here is the full error under torch==1.9.0:

RuntimeError                              Traceback (most recent call last)
in ()
     27         model = eval(args.model)(
     28             pretrained=(not args.scratch),
---> 29             img_size=crop_size, num_classes=num_classes, patch_size=args.patch, args=args).cuda()
     30         logger.info('Model {}'.format(model))
     31     else:

5 frames
/content/drive/My Drive/Colab Notebooks/on-the-adversarial-robustness-of-visual-transformer-master/training/timm_vit/vit.py in vit_base_patch16_224_in21k(pretrained, patch_size, args, **kwargs)
    336     model_kwargs = dict(
    337         patch_size=patch_size, embed_dim=768, depth=depth, num_heads=12, representation_size=768, **kwargs)
--> 338     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, args=args, **model_kwargs)
    339     return model
    340

/content/drive/My Drive/Colab Notebooks/on-the-adversarial-robustness-of-visual-transformer-master/training/timm_vit/vit.py in _create_vision_transformer(variant, pretrained, distilled, **kwargs)
    272             load_pretrained(
    273                 model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
--> 274                 filter_fn=partial(checkpoint_filter_fn, args=kwargs.pop('args'), model=model))
    275         else:
    276             _logger.warning('Training from scratch')

/usr/local/lib/python3.7/dist-packages/timm/models/helpers.py in load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn, strict, progress)
    182     else:
    183         _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
--> 184         state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
    185     if filter_fn is not None:
    186         # for backwards compat with filter fn that take one arg, try one first, the two

/usr/local/lib/python3.7/dist-packages/torch/hub.py in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name)
    556     if _is_legacy_zip_format(cached_file):
    557         return _legacy_zip_load(cached_file, model_dir, map_location)
--> 558     return torch.load(cached_file, map_location=map_location)

/usr/local/lib/python3.7/dist-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    598             # reset back to the original position.
    599             orig_position = opened_file.tell()
--> 600             with _open_zipfile_reader(opened_file) as opened_zipfile:
    601                 if _is_torchscript_zip(opened_zipfile):
    602                     warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"

/usr/local/lib/python3.7/dist-packages/torch/serialization.py in __init__(self, name_or_buffer)
    240 class _open_zipfile_reader(_opener):
    241     def __init__(self, name_or_buffer) -> None:
--> 242         super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
    243
    244

RuntimeError: Expected hasRecord("version") to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
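The traceback shows the failure inside `torch.load`, when `_open_zipfile_reader` opens the downloaded file as a PyTorch zip checkpoint and finds no `version` record. One plausible explanation, which is an assumption not confirmed in this thread, is that the downloaded file is a zip of a different kind: a NumPy `.npz` archive, for instance, is also a zip file but has no `version` entry, and a multi-entry zip would also be consistent with the "Only one file(not dir) is allowed in the zipfile" message from older torch releases. The helper `classify_checkpoint` below is a hypothetical name, not a torch or timm API; it is a minimal sketch that uses only the standard-library `zipfile` module to guess which kind of file was downloaded.

```python
# Hypothetical diagnostic (not part of torch or timm): guess what kind of
# file a downloaded "checkpoint" actually is, using only the stdlib.
import zipfile
from pathlib import Path


def classify_checkpoint(path):
    """Return 'not-zip', 'torch-zip', 'numpy-npz', or 'unknown-zip'."""
    path = Path(path)
    if not zipfile.is_zipfile(path):
        # e.g. a legacy (non-zip) torch.save pickle stream, or a corrupt file
        return "not-zip"
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
    # torch.save (zip format, torch >= 1.6) stores its records under one
    # top-level folder, including "<folder>/version", which torch.load expects.
    if any(n.count("/") == 1 and n.endswith("/version") for n in names):
        return "torch-zip"
    # np.savez stores one "<key>.npy" entry per array and no "version" record.
    if names and all(n.endswith(".npy") for n in names):
        return "numpy-npz"
    return "unknown-zip"
```

For example, running it on the cached file in torch's hub cache (the `checkpoints` folder under `torch.hub.get_dir()`) and getting `"numpy-npz"` would support this explanation, while `"torch-zip"` would point elsewhere, such as a mismatch between the loading code and the installed library versions.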

shizhouxing commented 3 years ago

Hi @caposerenity, which version of timm are you using? It seems that the latest version does not work here. Could you please try installing timm==0.4.5?
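The fix suggested here is to pin the library with `pip install timm==0.4.5`. As a minimal sketch, not part of the repository, a training script could check the installed version before creating the model, so that a mismatch fails with a clear message instead of the zip error in the traceback. `check_timm_version` and `PINNED_TIMM` are hypothetical names, and the snippet assumes Python 3.8+ for `importlib.metadata`.

```python
# Hypothetical guard: fail early if the installed timm differs from the
# version suggested in this thread (timm==0.4.5). Not part of the repo.
from importlib import metadata

PINNED_TIMM = "0.4.5"


def check_timm_version(expected=PINNED_TIMM):
    """Return the installed timm version, or raise if it is missing or differs."""
    hint = f"try: pip install timm=={expected}"
    try:
        installed = metadata.version("timm")
    except metadata.PackageNotFoundError:
        raise RuntimeError(f"timm is not installed; {hint}") from None
    if installed != expected:
        raise RuntimeError(
            f"found timm=={installed}, expected timm=={expected}; {hint}"
        )
    return installed
```

An exact-match check is deliberately strict: the thread only reports that 0.4.5 works, not which later releases might also work.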

caposerenity commented 3 years ago

Thank you, it really helps!