liangyanshuo / InfLoRA

The official implementation of the CVPR'2024 work Interference-Free Low-Rank Adaptation for Continual Learning
MIT License

How to use self-supervised weights #1

Open Ghy0501 opened 3 months ago

Ghy0501 commented 3 months ago

Hi, amazing work!

I noticed that in the paper you use self-supervised weights, but I ran into problems when trying to reproduce the results. Specifically, I directly changed 'vit_base_patch16_224_in21k' to 'vit_base_patch16_224_dino' in models/sinet_inflora.py line 68, and then the following error occurred:

```
2024-04-10 12:56:43,013 [helpers.py] => Loading pretrained weights from url (https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth)
Traceback (most recent call last):
  File "/root/InfLoRA/main.py", line 33, in <module>
    main()
  File "/root/InfLoRA/main.py", line 11, in main
    train(args)
  File "/root/InfLoRA/trainer.py", line 21, in train
    _train(args)
  File "/root/InfLoRA/trainer.py", line 55, in _train
    model = factory.get_model(args['model_name'], args)
  File "/root/InfLoRA/utils/factory.py", line 19, in get_model
    return optionsname
  File "/root/InfLoRA/methods/inflora.py", line 27, in __init__
    self._network = SiNet(args)
  File "/root/InfLoRA/models/sinet_inflora.py", line 68, in __init__
    self.image_encoder = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=True, **model_kwargs)
  File "/root/InfLoRA/models/sinet_inflora.py", line 51, in _create_vision_transformer
    model = build_model_with_cfg(
  File "/root/miniconda3/envs/pyt3_9/lib/python3.9/site-packages/timm/models/helpers.py", line 545, in build_model_with_cfg
    load_pretrained(
  File "/root/miniconda3/envs/pyt3_9/lib/python3.9/site-packages/timm/models/helpers.py", line 296, in load_pretrained
    model.load_state_dict(state_dict, strict=strict)
  File "/root/miniconda3/envs/pyt3_9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ViT_lora_co:
	Missing key(s) in state_dict: "cls_token_grow", "pos_embed_grow", "blocks.0.attn.lora_A_k.0.weight", ......
```

How can I deal with this error, and how can I use other self-supervised weights, e.g., DeiT?

I'm looking forward to your reply!

liangyanshuo commented 2 months ago

Hi,

The error occurred because some of the weights in the model, such as the LoRA weights, are not included in the loaded self-supervised pretrained weights. To resolve this issue, you can set `pretrained_strict=False` in `model_kwargs` before loading the pretrained weights into the model. Specifically, in models/sinet_inflora.py line 67, you can modify the code as follows:

```python
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
                    n_tasks=args["total_sessions"], rank=args["rank"],
                    pretrained_strict=False)
```
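For intuition, `pretrained_strict=False` is forwarded by timm down to PyTorch's `model.load_state_dict(state_dict, strict=False)`, which tolerates model parameters that are absent from the checkpoint instead of raising. A minimal standalone sketch (with hypothetical module and parameter names, not InfLoRA's actual classes):

```python
import torch
import torch.nn as nn

# A plain backbone whose weights stand in for the pretrained checkpoint.
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

# The same backbone extended with a LoRA-style parameter that the
# checkpoint does not contain (analogous to "blocks.0.attn.lora_A_k...").
class BackboneWithLoRA(Backbone):
    def __init__(self):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(2, 4))

state_dict = Backbone().state_dict()  # stands in for the DINO weights
model = BackboneWithLoRA()

# strict=True would raise RuntimeError: Missing key(s) "lora_A".
# strict=False loads the overlapping keys and reports the rest.
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['lora_A']
print(result.unexpected_keys)  # []
```

The returned `missing_keys` list is worth printing once when switching checkpoints, to confirm that only the intentionally new parameters (LoRA matrices, grown tokens) are left randomly initialized.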

I will update the code to support self-supervised pretrained weights in a few days. Thanks for reminding me!

Ghy0501 commented 2 months ago

Thank you for your reply, my problem has been solved, good luck with your research!