JH-LEE-KR / l2p-pytorch

PyTorch Implementation of Learning to Prompt (L2P) for Continual Learning @ CVPR22
Apache License 2.0

vit_base_patch16_224 #10

Closed Geeks-Z closed 1 year ago

Geeks-Z commented 1 year ago

Hello! Thank you so much for implementing the PyTorch version of L2P! When obtaining the pre-trained model, I ran into the problem shown in the screenshot below. Has anyone else encountered something similar?

[screenshot: Snipaste_2023-04-24_21-29-49]

Thank you for your work! Looking forward to your reply, best wishes!

JessicaGao0527 commented 1 year ago

I tried to initialize the vit_base_patch16_224 model in my Jupyter notebook, and the timm library could not find the registered model, which is defined in this repo's `vision_transformer.py`. You can try something like this:

```python
from timm.models import create_model  # timm's model factory
from vision_transformer import vit_base_patch16_224  # noqa: F401 -- importing the repo's module registers its prompt-aware ViT with timm

model = create_model(
    args.model,
    pretrained=args.pretrained,
    num_classes=args.nb_classes,
    drop_rate=args.drop,
    drop_path_rate=args.drop_path,
    drop_block_rate=None,
    prompt_length=args.length,
    embedding_key=args.embedding_key,
    prompt_init=args.prompt_key_init,
    prompt_pool=args.prompt_pool,
    prompt_key=args.prompt_key,
    pool_size=args.size,
    top_k=args.top_k,
    batchwise_prompt=args.batchwise_prompt,
    prompt_key_init=args.prompt_key_init,
    head_type=args.head_type,
    use_prompt_mask=args.use_prompt_mask,
)
```
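The import matters because it executes the `@register_model` decorators in the repo's `vision_transformer.py`, which is how timm's `create_model` factory learns about the custom ViT. If you want to double-check which implementation the name resolves to from a notebook, a minimal sketch like the one below should work (it assumes the l2p-pytorch repo root is your working directory or on `sys.path`; the `num_classes` value is just a placeholder):

```python
# Sketch only: confirm that 'vit_base_patch16_224' resolves to the repo's module.
# Assumes the l2p-pytorch repo root is on sys.path.
import vision_transformer  # noqa: F401 -- runs the @register_model decorators
from timm.models import create_model

m = create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
# Expected to print 'vision_transformer' (the repo's module) rather than a timm module.
print(type(m).__module__)
```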

Geeks-Z commented 1 year ago

Thank you

Geeks-Z commented 1 year ago

The final code is as follows:

```python
from timm.models import create_model  # timm's model factory
from vision_transformer import vit_base_patch16_224  # noqa: F401 -- registers the repo's prompt-aware ViT with timm

print(f"Creating original model: {args.model}")
original_model = create_model(
    args.model,
    pretrained=args.pretrained,
    num_classes=args.nb_classes,
    drop_rate=args.drop,
    drop_path_rate=args.drop_path,
    drop_block_rate=None,
)

print(f"Creating model: {args.model}")
model = create_model(
    args.model,
    pretrained=args.pretrained,
    num_classes=args.nb_classes,
    drop_rate=args.drop,
    drop_path_rate=args.drop_path,
    drop_block_rate=None,
    prompt_length=args.length,
    embedding_key=args.embedding_key,
    prompt_init=args.prompt_key_init,
    prompt_pool=args.prompt_pool,
    prompt_key=args.prompt_key,
    pool_size=args.size,
    top_k=args.top_k,
    batchwise_prompt=args.batchwise_prompt,
    prompt_key_init=args.prompt_key_init,
    head_type=args.head_type,
    use_prompt_mask=args.use_prompt_mask,
)

original_model.to(device)
```
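For anyone trying this in a notebook where `args` is not supplied by the repo's argument parser, a hypothetical stand-in can be built by hand. The snippet below is only a sketch: the attribute names mirror the call above, but the values are illustrative placeholders rather than the repo's defaults.

```python
# Hypothetical stand-in for `args` when running outside the repo's training script.
# All values are illustrative placeholders, not the repo's defaults.
from argparse import Namespace
import torch

args = Namespace(
    model='vit_base_patch16_224',
    pretrained=True,
    nb_classes=100,            # placeholder, e.g. a 100-class benchmark
    drop=0.0,
    drop_path=0.0,
    length=5,                  # prompt length (placeholder)
    embedding_key='cls',
    prompt_key_init='uniform',
    prompt_pool=True,
    prompt_key=True,
    size=10,                   # prompt pool size (placeholder)
    top_k=5,
    batchwise_prompt=True,
    head_type='token',
    use_prompt_mask=False,
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```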