siyi-wind / AViT

[MICCAI ISIC Workshop 2023 (best paper)] AViT: Adapting Vision Transformers for Small Skin Lesion Segmentation Datasets (an official implementation)

test model #1

Closed · dysion closed this 2 months ago

dysion commented 9 months ago

Hi! Thanks for your wonderful work! Could you share the trained AViT model weights? I'd just like to test them out of interest. Thank you!

siyi-wind commented 9 months ago

Sure. I have uploaded the weights for AViT based on ViT-Base.

dysion commented 9 months ago

Thank you!

dysion commented 8 months ago

Hi! I tried to run some tests with the published weights, but encountered errors.

The bash command is: python test_avit.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg --batch_size 16 --adapt_method False --num_domains 1 --dataset isic2018 --k_fold 0 --test_dir /home/deeplearning/disk_2/dataset/test-scar/

The test model is ViTSeg and the weights file is AViT_ViT_B_ISIC_best.pth. When calling model.load_state_dict(torch.load(model_dir)), the errors are:

RuntimeError: Error(s) in loading state_dict for ViTSeg_adapt: Missing key(s) in state_dict: "final_conv.weight", "final_conv.bias". Unexpected key(s) in state_dict: "prompt_encoder.conv1.weight", "prompt_encoder.bn1.weight", "prompt_encoder.bn1.bias", "prompt_encoder.bn1.running_mean", "prompt_encoder.bn1.running_var", "prompt_encoder.bn1.num_batches_tracked", "prompt_encoder.layer1.0.conv1.weight", "prompt_encoder.layer1.0.bn1.weight", "prompt_encoder.layer1.0.bn1.bias", "prompt_encoder.layer1.0.bn1.running_mean", "prompt_encoder.layer1.0.bn1.running_var", "prompt_encoder.layer1.0.bn1.num_batches_tracked", "prompt_encoder.layer1.0.conv2.weight", "prompt_encoder.layer1.0.bn2.weight", "prompt_encoder.layer1.0.bn2.bias", "prompt_encoder.layer1.0.bn2.running_mean", "prompt_encoder.layer1.0.bn2.running_var", "prompt_encoder.layer1.0.bn2.num_batches_tracked", "prompt_encoder.layer1.1.conv1.weight", "prompt_encoder.layer1.1.bn1.weight", "prompt_encoder.layer1.1.bn1.bias", "prompt_encoder.layer1.1.bn1.running_mean", "prompt_encoder.layer1.1.bn1.running_var", "prompt_encoder.layer1.1.bn1.num_batches_tracked", "prompt_encoder.layer1.1.conv2.weight", "prompt_encoder.layer1.1.bn2.weight", "prompt_encoder.layer1.1.bn2.bias", "prompt_encoder.layer1.1.bn2.running_mean", "prompt_encoder.layer1.1.bn2.running_var", "prompt_encoder.layer1.1.bn2.num_batches_tracked", "prompt_encoder.layer1.2.conv1.weight", "prompt_encoder.layer1.2.bn1.weight", "prompt_encoder.layer1.2.bn1.bias", "prompt_encoder.layer1.2.bn1.running_mean", "prompt_encoder.layer1.2.bn1.running_var", "prompt_encoder.layer1.2.bn1.num_batches_tracked", "prompt_encoder.layer1.2.conv2.weight", "prompt_encoder.layer1.2.bn2.weight", "prompt_encoder.layer1.2.bn2.bias", "prompt_encoder.layer1.2.bn2.running_mean", "prompt_encoder.layer1.2.bn2.running_var", "prompt_encoder.layer1.2.bn2.num_batches_tracked", "encoder.norm.weight", "encoder.norm.bias", "encoder.blocks.0.adapter1.0.D_fc1.weight", "encoder.blocks.0.adapter1.0.D_fc1.bias", "encoder.blocks.0.adapter1.0.D_fc2.weight", "encoder.blocks.0.adapter1.0.D_fc2.bias", "encoder.blocks.0.adapter2.0.D_fc1.weight", "encoder.blocks.0.adapter2.0.D_fc1.bias", "encoder.blocks.0.adapter2.0.D_fc2.weight", "encoder.blocks.0.adapter2.0.D_fc2.bias", "encoder.blocks.1.adapter1.0.D_fc1.weight", "encoder.blocks.1.adapter1.0.D_fc1.bias", "encoder.blocks.1.adapter1.0.D_fc2.weight", "encoder.blocks.1.adapter1.0.D_fc2.bias", "encoder.blocks.1.adapter2.0.D_fc1.weight", "encoder.blocks.1.adapter2.0.D_fc1.bias", "encoder.blocks.1.adapter2.0.D_fc2.weight", "encoder.blocks.1.adapter2.0.D_fc2.bias", "encoder.blocks.2.adapter1.0.D_fc1.weight", "encoder.blocks.2.adapter1.0.D_fc1.bias", "encoder.blocks.2.adapter1.0.D_fc2.weight", "encoder.blocks.2.adapter1.0.D_fc2.bias", "encoder.blocks.2.adapter2.0.D_fc1.weight", "encoder.blocks.2.adapter2.0.D_fc1.bias", "encoder.blocks.2.adapter2.0.D_fc2.weight", "encoder.blocks.2.adapter2.0.D_fc2.bias", "encoder.blocks.3.adapter1.0.D_fc1.weight", "encoder.blocks.3.adapter1.0.D_fc1.bias", "encoder.blocks.3.adapter1.0.D_fc2.weight", "encoder.blocks.3.adapter1.0.D_fc2.bias", "encoder.blocks.3.adapter2.0.D_fc1.weight", "encoder.blocks.3.adapter2.0.D_fc1.bias", "encoder.blocks.3.adapter2.0.D_fc2.weight", "encoder.blocks.3.adapter2.0.D_fc2.bias", "encoder.blocks.4.adapter1.0.D_fc1.weight", "encoder.blocks.4.adapter1.0.D_fc1.bias", "encoder.blocks.4.adapter1.0.D_fc2.weight", "encoder.blocks.4.adapter1.0.D_fc2.bias", "encoder.blocks.4.adapter2.0.D_fc1.weight", "encoder.blocks.4.adapter2.0.D_fc1.bias", 
"encoder.blocks.4.adapter2.0.D_fc2.weight", "encoder.blocks.4.adapter2.0.D_fc2.bias", "encoder.blocks.5.adapter1.0.D_fc1.weight", "encoder.blocks.5.adapter1.0.D_fc1.bias", "encoder.blocks.5.adapter1.0.D_fc2.weight", "encoder.blocks.5.adapter1.0.D_fc2.bias", "encoder.blocks.5.adapter2.0.D_fc1.weight", "encoder.blocks.5.adapter2.0.D_fc1.bias", "encoder.blocks.5.adapter2.0.D_fc2.weight", "encoder.blocks.5.adapter2.0.D_fc2.bias", "encoder.blocks.6.adapter1.0.D_fc1.weight", "encoder.blocks.6.adapter1.0.D_fc1.bias", "encoder.blocks.6.adapter1.0.D_fc2.weight", "encoder.blocks.6.adapter1.0.D_fc2.bias", "encoder.blocks.6.adapter2.0.D_fc1.weight", "encoder.blocks.6.adapter2.0.D_fc1.bias", "encoder.blocks.6.adapter2.0.D_fc2.weight", "encoder.blocks.6.adapter2.0.D_fc2.bias", "encoder.blocks.7.adapter1.0.D_fc1.weight", "encoder.blocks.7.adapter1.0.D_fc1.bias", "encoder.blocks.7.adapter1.0.D_fc2.weight", "encoder.blocks.7.adapter1.0.D_fc2.bias", "encoder.blocks.7.adapter2.0.D_fc1.weight", "encoder.blocks.7.adapter2.0.D_fc1.bias", "encoder.blocks.7.adapter2.0.D_fc2.weight", "encoder.blocks.7.adapter2.0.D_fc2.bias", "encoder.blocks.8.adapter1.0.D_fc1.weight", "encoder.blocks.8.adapter1.0.D_fc1.bias", "encoder.blocks.8.adapter1.0.D_fc2.weight", "encoder.blocks.8.adapter1.0.D_fc2.bias", "encoder.blocks.8.adapter2.0.D_fc1.weight", "encoder.blocks.8.adapter2.0.D_fc1.bias", "encoder.blocks.8.adapter2.0.D_fc2.weight", "encoder.blocks.8.adapter2.0.D_fc2.bias", "encoder.blocks.9.adapter1.0.D_fc1.weight", "encoder.blocks.9.adapter1.0.D_fc1.bias", "encoder.blocks.9.adapter1.0.D_fc2.weight", "encoder.blocks.9.adapter1.0.D_fc2.bias", "encoder.blocks.9.adapter2.0.D_fc1.weight", "encoder.blocks.9.adapter2.0.D_fc1.bias", "encoder.blocks.9.adapter2.0.D_fc2.weight", "encoder.blocks.9.adapter2.0.D_fc2.bias", "encoder.blocks.10.adapter1.0.D_fc1.weight", "encoder.blocks.10.adapter1.0.D_fc1.bias", "encoder.blocks.10.adapter1.0.D_fc2.weight", "encoder.blocks.10.adapter1.0.D_fc2.bias", "encoder.blocks.10.adapter2.0.D_fc1.weight", "encoder.blocks.10.adapter2.0.D_fc1.bias", "encoder.blocks.10.adapter2.0.D_fc2.weight", "encoder.blocks.10.adapter2.0.D_fc2.bias", "encoder.blocks.11.adapter1.0.D_fc1.weight", "encoder.blocks.11.adapter1.0.D_fc1.bias", "encoder.blocks.11.adapter1.0.D_fc2.weight", "encoder.blocks.11.adapter1.0.D_fc2.bias", "encoder.blocks.11.adapter2.0.D_fc1.weight", "encoder.blocks.11.adapter2.0.D_fc1.bias", "encoder.blocks.11.adapter2.0.D_fc2.weight", "encoder.blocks.11.adapter2.0.D_fc2.bias", "final_conv.0.weight", "final_conv.0.bias", "final_conv.1.weight", "final_conv.1.bias", "final_conv.1.running_mean", "final_conv.1.running_var", "final_conv.1.num_batches_tracked", "final_conv.3.weight", "final_conv.3.bias", "final_conv.4.weight", "final_conv.4.bias", "final_conv.4.running_mean", "final_conv.4.running_var", "final_conv.4.num_batches_tracked", "final_conv.6.weight", "final_conv.6.bias".

siyi-wind commented 8 months ago

Hi dysion, the weights are for the AViT models. Here is the correspondence between the model names in the code and in the paper:

| Paper | Code |
| ----- | ---- |
| BASE | ViTSeg / SwinSeg / DeiTSeg |
| AViT | ViTSeg_CNNprompt_adapt / SwinSeg_CNNprompt_adapt / DeiTSeg_CNNprompt_adapt |

Thus, if you want to test AViT (ViT-based), the following command should work:

python test_avit.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg_CNNprompt_adapt --batch_size 16 --adapt_method False --num_domains 1 --dataset isic2018 --k_fold 0 --test_dir /home/deeplearning/disk_2/dataset/test-scar/

dysion commented 8 months ago

Hi siyi, I ran the command you gave:

python test_avit.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg_CNNprompt_adapt --batch_size 16 --adapt_method False --num_domains 1 --dataset isic2018 --k_fold 0 --test_dir /home/deeplearning/disk_2/dataset/test-scar/

and used the model weights AViT_ViT_B_ISIC_best.pth.

The output is:

In init ViTSeg_CNNprompt_adapt, config.model_adapt.adapt_method: False
In vit_adapters, ViTSeg_CNNprompt_adapt, adapt_method: False
92.282625M total parameters
6.522369M total trainable parameters

Traceback (most recent call last):
  File "test_avit.py", line 409, in <module>
    model.load_state_dict(torch.load(model_dir))
  File "/home/amax/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ViTSeg_CNNprompt_adapt:
	Unexpected key(s) in state_dict: "encoder.norm.weight", "encoder.norm.bias", "encoder.blocks.0.adapter1.0.D_fc1.weight", "encoder.blocks.0.adapter1.0.D_fc1.bias", "encoder.blocks.0.adapter1.0.D_fc2.weight", "encoder.blocks.0.adapter1.0.D_fc2.bias", "encoder.blocks.0.adapter2.0.D_fc1.weight", "encoder.blocks.0.adapter2.0.D_fc1.bias", "encoder.blocks.0.adapter2.0.D_fc2.weight", "encoder.blocks.0.adapter2.0.D_fc2.bias", "encoder.blocks.1.adapter1.0.D_fc1.weight", "encoder.blocks.1.adapter1.0.D_fc1.bias", "encoder.blocks.1.adapter1.0.D_fc2.weight", "encoder.blocks.1.adapter1.0.D_fc2.bias", "encoder.blocks.1.adapter2.0.D_fc1.weight", "encoder.blocks.1.adapter2.0.D_fc1.bias", "encoder.blocks.1.adapter2.0.D_fc2.weight", "encoder.blocks.1.adapter2.0.D_fc2.bias", "encoder.blocks.2.adapter1.0.D_fc1.weight", "encoder.blocks.2.adapter1.0.D_fc1.bias", "encoder.blocks.2.adapter1.0.D_fc2.weight", "encoder.blocks.2.adapter1.0.D_fc2.bias", "encoder.blocks.2.adapter2.0.D_fc1.weight", "encoder.blocks.2.adapter2.0.D_fc1.bias", "encoder.blocks.2.adapter2.0.D_fc2.weight", "encoder.blocks.2.adapter2.0.D_fc2.bias", "encoder.blocks.3.adapter1.0.D_fc1.weight", "encoder.blocks.3.adapter1.0.D_fc1.bias", "encoder.blocks.3.adapter1.0.D_fc2.weight", "encoder.blocks.3.adapter1.0.D_fc2.bias", "encoder.blocks.3.adapter2.0.D_fc1.weight", "encoder.blocks.3.adapter2.0.D_fc1.bias", "encoder.blocks.3.adapter2.0.D_fc2.weight", "encoder.blocks.3.adapter2.0.D_fc2.bias", "encoder.blocks.4.adapter1.0.D_fc1.weight", "encoder.blocks.4.adapter1.0.D_fc1.bias", "encoder.blocks.4.adapter1.0.D_fc2.weight", "encoder.blocks.4.adapter1.0.D_fc2.bias", "encoder.blocks.4.adapter2.0.D_fc1.weight", "encoder.blocks.4.adapter2.0.D_fc1.bias", "encoder.blocks.4.adapter2.0.D_fc2.weight", "encoder.blocks.4.adapter2.0.D_fc2.bias", "encoder.blocks.5.adapter1.0.D_fc1.weight", "encoder.blocks.5.adapter1.0.D_fc1.bias", "encoder.blocks.5.adapter1.0.D_fc2.weight", "encoder.blocks.5.adapter1.0.D_fc2.bias", "encoder.blocks.5.adapter2.0.D_fc1.weight", "encoder.blocks.5.adapter2.0.D_fc1.bias", "encoder.blocks.5.adapter2.0.D_fc2.weight", "encoder.blocks.5.adapter2.0.D_fc2.bias", "encoder.blocks.6.adapter1.0.D_fc1.weight", "encoder.blocks.6.adapter1.0.D_fc1.bias", "encoder.blocks.6.adapter1.0.D_fc2.weight", "encoder.blocks.6.adapter1.0.D_fc2.bias", "encoder.blocks.6.adapter2.0.D_fc1.weight", "encoder.blocks.6.adapter2.0.D_fc1.bias", "encoder.blocks.6.adapter2.0.D_fc2.weight", "encoder.blocks.6.adapter2.0.D_fc2.bias", "encoder.blocks.7.adapter1.0.D_fc1.weight", "encoder.blocks.7.adapter1.0.D_fc1.bias", "encoder.blocks.7.adapter1.0.D_fc2.weight", "encoder.blocks.7.adapter1.0.D_fc2.bias", "encoder.blocks.7.adapter2.0.D_fc1.weight", "encoder.blocks.7.adapter2.0.D_fc1.bias", "encoder.blocks.7.adapter2.0.D_fc2.weight", "encoder.blocks.7.adapter2.0.D_fc2.bias", "encoder.blocks.8.adapter1.0.D_fc1.weight", "encoder.blocks.8.adapter1.0.D_fc1.bias", "encoder.blocks.8.adapter1.0.D_fc2.weight", "encoder.blocks.8.adapter1.0.D_fc2.bias",
"encoder.blocks.8.adapter2.0.D_fc1.weight", "encoder.blocks.8.adapter2.0.D_fc1.bias", "encoder.blocks.8.adapter2.0.D_fc2.weight", "encoder.blocks.8.adapter2.0.D_fc2.bias", "encoder.blocks.9.adapter1.0.D_fc1.weight", "encoder.blocks.9.adapter1.0.D_fc1.bias", "encoder.blocks.9.adapter1.0.D_fc2.weight", "encoder.blocks.9.adapter1.0.D_fc2.bias", "encoder.blocks.9.adapter2.0.D_fc1.weight", "encoder.blocks.9.adapter2.0.D_fc1.bias", "encoder.blocks.9.adapter2.0.D_fc2.weight", "encoder.blocks.9.adapter2.0.D_fc2.bias", "encoder.blocks.10.adapter1.0.D_fc1.weight", "encoder.blocks.10.adapter1.0.D_fc1.bias", "encoder.blocks.10.adapter1.0.D_fc2.weight", "encoder.blocks.10.adapter1.0.D_fc2.bias", "encoder.blocks.10.adapter2.0.D_fc1.weight", "encoder.blocks.10.adapter2.0.D_fc1.bias", "encoder.blocks.10.adapter2.0.D_fc2.weight", "encoder.blocks.10.adapter2.0.D_fc2.bias", "encoder.blocks.11.adapter1.0.D_fc1.weight", "encoder.blocks.11.adapter1.0.D_fc1.bias", "encoder.blocks.11.adapter1.0.D_fc2.weight", "encoder.blocks.11.adapter1.0.D_fc2.bias", "encoder.blocks.11.adapter2.0.D_fc1.weight", "encoder.blocks.11.adapter2.0.D_fc1.bias", "encoder.blocks.11.adapter2.0.D_fc2.weight", "encoder.blocks.11.adapter2.0.D_fc2.bias".

siyi-wind commented 8 months ago

Hi Dysion, the adapt_method for AViT should be MLP, as in the command I gave in the repository (for ViTSeg_CNNprompt_adapt, SwinSeg_CNNprompt_adapt, and DeiTSeg_CNNprompt_adapt):

python -u multi_train_adapt.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg_CNNprompt_adapt --batch_size 16 --adapt_method MLP --num_domains 1 --dataset isic2018 --k_fold 0

AViT takes an image and a domain label as input, which can be written as output = model(img, d=d)['seg']. Since the input data here comes from a single domain (dataset), d is '0'. You can refer to multi_train_adapt.py for more information on how to run the model.
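As a rough illustration, a minimal inference sketch (the run_inference helper and the dummy input are illustrative, not the repository's actual pipeline):

```python
import torch

def run_inference(model, img):
    """Run an AViT variant (e.g. ViTSeg_CNNprompt_adapt) on single-domain data."""
    model.eval()
    with torch.no_grad():
        # The model takes the image and a domain label; 'seg' holds the segmentation output.
        return model(img, d='0')['seg']

# e.g. seg = run_inference(model, torch.randn(1, 3, 224, 224)),
# with `model` constructed and loaded as above; the real image size comes from the config.
```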

dysion commented 8 months ago

Hi! After changing --adapt_method from False to MLP, there are still some errors:

Traceback (most recent call last):
  File "test_avit.py", line 409, in <module>
    model.load_state_dict(torch.load(model_dir))
  File "/home/amax/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ViTSeg_CNNprompt_adapt:
	Unexpected key(s) in state_dict: "encoder.norm.weight", "encoder.norm.bias".

The error comes from load_state_dict(torch.load(model_dir)). My question is whether the weights file actually matches the ViTSeg_CNNprompt_adapt model.

Besides, I initialized the model as follows, changing the parameter pretrained=False:

model = ViTSeg_CNNprompt_adapt(
    pretrained=False,
    pretrained_vit_name=config.vit.name,
    pretrained_folder=config.pretrained_folder,
    img_size=config.data.img_size,
    patch_size=config.vit.patch_size,
    embed_dim=config.vit.embed_dim,
    depth=config.vit.depth,
    num_heads=config.vit.num_heads,
    mlp_ratio=config.vit.mlp_ratio,
    drop_rate=config.vit.dropout_rate,
    attn_drop_rate=config.vit.attention_dropout_rate,
    drop_path_rate=0.2,
    debug=config.debug,
    adapt_method=config.model_adapt.adapt_method,
    num_domains=K,
)

Thank you!

siyi-wind commented 8 months ago

Hi Dysion, sorry for the confusion. encoder.norm is a norm layer that is unused during training. I have fixed the code by adding the unused norm to the model so that the architecture matches the weights. If you encounter this issue in the future, just ignore it by not loading the norm.
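For reference, one generic way to skip those keys in PyTorch (a sketch assuming the checkpoint file is a plain state_dict; `model` is assumed to be built as earlier in this thread):

```python
import torch

ckpt = torch.load("AViT_ViT_B_ISIC_best.pth", map_location="cpu")
# Drop the unused norm entries so the remaining keys match the model.
ckpt = {k: v for k, v in ckpt.items() if not k.startswith("encoder.norm.")}
# strict=False reports any remaining mismatches instead of raising.
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print("missing:", missing, "unexpected:", unexpected)
```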

Besides, I just uploaded a play.ipynb file, which tests AViT on ISIC as an example. Please let me know if you face any issues.

dysion commented 8 months ago

Thank you! It works now.