dysion closed this issue 2 months ago.
Sure. I have uploaded the weights for AViT based on ViT-Base.
Thank you!
Hi! I tried to use the published weights for some tests, but ran into errors.
The bash command is:

```bash
python test_avit.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg --batch_size 16 --adapt_method False --num_domains 1 --dataset isic2018 --k_fold 0 --test_dir /home/deeplearning/disk_2/dataset/test-scar/
```
The test model is `ViTSeg` and the weights file is `AViT_ViT_B_ISIC_best.pth`. Calling `model.load_state_dict(torch.load(model_dir))` raises the following error:
RuntimeError: Error(s) in loading state_dict for ViTSeg_adapt: Missing key(s) in state_dict: "final_conv.weight", "final_conv.bias". Unexpected key(s) in state_dict: "prompt_encoder.conv1.weight", "prompt_encoder.bn1.weight", "prompt_encoder.bn1.bias", "prompt_encoder.bn1.running_mean", "prompt_encoder.bn1.running_var", "prompt_encoder.bn1.num_batches_tracked", "prompt_encoder.layer1.0.conv1.weight", "prompt_encoder.layer1.0.bn1.weight", "prompt_encoder.layer1.0.bn1.bias", "prompt_encoder.layer1.0.bn1.running_mean", "prompt_encoder.layer1.0.bn1.running_var", "prompt_encoder.layer1.0.bn1.num_batches_tracked", "prompt_encoder.layer1.0.conv2.weight", "prompt_encoder.layer1.0.bn2.weight", "prompt_encoder.layer1.0.bn2.bias", "prompt_encoder.layer1.0.bn2.running_mean", "prompt_encoder.layer1.0.bn2.running_var", "prompt_encoder.layer1.0.bn2.num_batches_tracked", "prompt_encoder.layer1.1.conv1.weight", "prompt_encoder.layer1.1.bn1.weight", "prompt_encoder.layer1.1.bn1.bias", "prompt_encoder.layer1.1.bn1.running_mean", "prompt_encoder.layer1.1.bn1.running_var", "prompt_encoder.layer1.1.bn1.num_batches_tracked", "prompt_encoder.layer1.1.conv2.weight", "prompt_encoder.layer1.1.bn2.weight", "prompt_encoder.layer1.1.bn2.bias", "prompt_encoder.layer1.1.bn2.running_mean", "prompt_encoder.layer1.1.bn2.running_var", "prompt_encoder.layer1.1.bn2.num_batches_tracked", "prompt_encoder.layer1.2.conv1.weight", "prompt_encoder.layer1.2.bn1.weight", "prompt_encoder.layer1.2.bn1.bias", "prompt_encoder.layer1.2.bn1.running_mean", "prompt_encoder.layer1.2.bn1.running_var", "prompt_encoder.layer1.2.bn1.num_batches_tracked", "prompt_encoder.layer1.2.conv2.weight", "prompt_encoder.layer1.2.bn2.weight", "prompt_encoder.layer1.2.bn2.bias", "prompt_encoder.layer1.2.bn2.running_mean", "prompt_encoder.layer1.2.bn2.running_var", "prompt_encoder.layer1.2.bn2.num_batches_tracked", "encoder.norm.weight", "encoder.norm.bias", "encoder.blocks.0.adapter1.0.D_fc1.weight", 
"encoder.blocks.0.adapter1.0.D_fc1.bias", "encoder.blocks.0.adapter1.0.D_fc2.weight", "encoder.blocks.0.adapter1.0.D_fc2.bias", "encoder.blocks.0.adapter2.0.D_fc1.weight", "encoder.blocks.0.adapter2.0.D_fc1.bias", "encoder.blocks.0.adapter2.0.D_fc2.weight", "encoder.blocks.0.adapter2.0.D_fc2.bias", "encoder.blocks.1.adapter1.0.D_fc1.weight", "encoder.blocks.1.adapter1.0.D_fc1.bias", "encoder.blocks.1.adapter1.0.D_fc2.weight", "encoder.blocks.1.adapter1.0.D_fc2.bias", "encoder.blocks.1.adapter2.0.D_fc1.weight", "encoder.blocks.1.adapter2.0.D_fc1.bias", "encoder.blocks.1.adapter2.0.D_fc2.weight", "encoder.blocks.1.adapter2.0.D_fc2.bias", "encoder.blocks.2.adapter1.0.D_fc1.weight", "encoder.blocks.2.adapter1.0.D_fc1.bias", "encoder.blocks.2.adapter1.0.D_fc2.weight", "encoder.blocks.2.adapter1.0.D_fc2.bias", "encoder.blocks.2.adapter2.0.D_fc1.weight", "encoder.blocks.2.adapter2.0.D_fc1.bias", "encoder.blocks.2.adapter2.0.D_fc2.weight", "encoder.blocks.2.adapter2.0.D_fc2.bias", "encoder.blocks.3.adapter1.0.D_fc1.weight", "encoder.blocks.3.adapter1.0.D_fc1.bias", "encoder.blocks.3.adapter1.0.D_fc2.weight", "encoder.blocks.3.adapter1.0.D_fc2.bias", "encoder.blocks.3.adapter2.0.D_fc1.weight", "encoder.blocks.3.adapter2.0.D_fc1.bias", "encoder.blocks.3.adapter2.0.D_fc2.weight", "encoder.blocks.3.adapter2.0.D_fc2.bias", "encoder.blocks.4.adapter1.0.D_fc1.weight", "encoder.blocks.4.adapter1.0.D_fc1.bias", "encoder.blocks.4.adapter1.0.D_fc2.weight", "encoder.blocks.4.adapter1.0.D_fc2.bias", "encoder.blocks.4.adapter2.0.D_fc1.weight", "encoder.blocks.4.adapter2.0.D_fc1.bias", "encoder.blocks.4.adapter2.0.D_fc2.weight", "encoder.blocks.4.adapter2.0.D_fc2.bias", "encoder.blocks.5.adapter1.0.D_fc1.weight", "encoder.blocks.5.adapter1.0.D_fc1.bias", "encoder.blocks.5.adapter1.0.D_fc2.weight", "encoder.blocks.5.adapter1.0.D_fc2.bias", "encoder.blocks.5.adapter2.0.D_fc1.weight", "encoder.blocks.5.adapter2.0.D_fc1.bias", "encoder.blocks.5.adapter2.0.D_fc2.weight", 
"encoder.blocks.5.adapter2.0.D_fc2.bias", "encoder.blocks.6.adapter1.0.D_fc1.weight", "encoder.blocks.6.adapter1.0.D_fc1.bias", "encoder.blocks.6.adapter1.0.D_fc2.weight", "encoder.blocks.6.adapter1.0.D_fc2.bias", "encoder.blocks.6.adapter2.0.D_fc1.weight", "encoder.blocks.6.adapter2.0.D_fc1.bias", "encoder.blocks.6.adapter2.0.D_fc2.weight", "encoder.blocks.6.adapter2.0.D_fc2.bias", "encoder.blocks.7.adapter1.0.D_fc1.weight", "encoder.blocks.7.adapter1.0.D_fc1.bias", "encoder.blocks.7.adapter1.0.D_fc2.weight", "encoder.blocks.7.adapter1.0.D_fc2.bias", "encoder.blocks.7.adapter2.0.D_fc1.weight", "encoder.blocks.7.adapter2.0.D_fc1.bias", "encoder.blocks.7.adapter2.0.D_fc2.weight", "encoder.blocks.7.adapter2.0.D_fc2.bias", "encoder.blocks.8.adapter1.0.D_fc1.weight", "encoder.blocks.8.adapter1.0.D_fc1.bias", "encoder.blocks.8.adapter1.0.D_fc2.weight", "encoder.blocks.8.adapter1.0.D_fc2.bias", "encoder.blocks.8.adapter2.0.D_fc1.weight", "encoder.blocks.8.adapter2.0.D_fc1.bias", "encoder.blocks.8.adapter2.0.D_fc2.weight", "encoder.blocks.8.adapter2.0.D_fc2.bias", "encoder.blocks.9.adapter1.0.D_fc1.weight", "encoder.blocks.9.adapter1.0.D_fc1.bias", "encoder.blocks.9.adapter1.0.D_fc2.weight", "encoder.blocks.9.adapter1.0.D_fc2.bias", "encoder.blocks.9.adapter2.0.D_fc1.weight", "encoder.blocks.9.adapter2.0.D_fc1.bias", "encoder.blocks.9.adapter2.0.D_fc2.weight", "encoder.blocks.9.adapter2.0.D_fc2.bias", "encoder.blocks.10.adapter1.0.D_fc1.weight", "encoder.blocks.10.adapter1.0.D_fc1.bias", "encoder.blocks.10.adapter1.0.D_fc2.weight", "encoder.blocks.10.adapter1.0.D_fc2.bias", "encoder.blocks.10.adapter2.0.D_fc1.weight", "encoder.blocks.10.adapter2.0.D_fc1.bias", "encoder.blocks.10.adapter2.0.D_fc2.weight", "encoder.blocks.10.adapter2.0.D_fc2.bias", "encoder.blocks.11.adapter1.0.D_fc1.weight", "encoder.blocks.11.adapter1.0.D_fc1.bias", "encoder.blocks.11.adapter1.0.D_fc2.weight", "encoder.blocks.11.adapter1.0.D_fc2.bias", "encoder.blocks.11.adapter2.0.D_fc1.weight", 
"encoder.blocks.11.adapter2.0.D_fc1.bias", "encoder.blocks.11.adapter2.0.D_fc2.weight", "encoder.blocks.11.adapter2.0.D_fc2.bias", "final_conv.0.weight", "final_conv.0.bias", "final_conv.1.weight", "final_conv.1.bias", "final_conv.1.running_mean", "final_conv.1.running_var", "final_conv.1.num_batches_tracked", "final_conv.3.weight", "final_conv.3.bias", "final_conv.4.weight", "final_conv.4.bias", "final_conv.4.running_mean", "final_conv.4.running_var", "final_conv.4.num_batches_tracked", "final_conv.6.weight", "final_conv.6.bias".
Hi dysion, those weights are for the AViT models. Here is the correspondence between the model names in the paper and in the code:
| Paper | Code |
|---|---|
| BASE | `ViTSeg` / `SwinSeg` / `DeiTSeg` |
| AViT | `ViTSeg_CNNprompt_adapt` / `SwinSeg_CNNprompt_adapt` / `DeiTSeg_CNNprompt_adapt` |
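The table can also be applied mechanically. The sketch below guesses which class a checkpoint belongs to from its key names. It is a heuristic based only on the key names visible in the error above: AViT checkpoints carry `prompt_encoder.*` and `*.adapter*` keys. The helper names `guess_family` and `suggest_model_class` are hypothetical and do not come from the repository.

```python
from typing import Iterable

# Paper name -> code class names, copied from the table above.
MODEL_FAMILIES = {
    "BASE": ["ViTSeg", "SwinSeg", "DeiTSeg"],
    "AViT": ["ViTSeg_CNNprompt_adapt", "SwinSeg_CNNprompt_adapt", "DeiTSeg_CNNprompt_adapt"],
}


def guess_family(keys: Iterable[str]) -> str:
    """Heuristic (assumption): AViT checkpoints contain prompt-encoder and adapter weights."""
    keys = list(keys)
    has_prompt = any(k.startswith("prompt_encoder.") for k in keys)
    has_adapter = any(".adapter" in k for k in keys)
    return "AViT" if has_prompt and has_adapter else "BASE"


def suggest_model_class(keys: Iterable[str], backbone: str) -> str:
    """Return the class name for backbone 'ViT', 'Swin' or 'DeiT' matching the keys."""
    family = guess_family(keys)
    for name in MODEL_FAMILIES[family]:
        if name == f"{backbone}Seg" or name.startswith(f"{backbone}Seg_"):
            return name
    raise ValueError(f"unknown backbone {backbone!r}")
```

For example, passing the keys of `AViT_ViT_B_ISIC_best.pth` with `backbone="ViT"` should point to `ViTSeg_CNNprompt_adapt` rather than `ViTSeg`.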
So, to test AViT (ViT-based), the following command should work:
```bash
python test_avit.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg_CNNprompt_adapt --batch_size 16 --adapt_method False --num_domains 1 --dataset isic2018 --k_fold 0 --test_dir /home/deeplearning/disk_2/dataset/test-scar/
```
Hi siyi, I ran the command you gave:
```bash
python test_avit.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg_CNNprompt_adapt --batch_size 16 --adapt_method False --num_domains 1 --dataset isic2018 --k_fold 0 --test_dir /home/deeplearning/disk_2/dataset/test-scar/
```
I used the model weights `AViT_ViT_B_ISIC_best.pth`. The output is:
```
In init ViTSeg_CNNprompt_adapt, config.model_adapt.adapt_method: False
In vit_adapters, ViTSeg_CNNprompt_adapt, adapt_method: False
92.282625M total parameters
6.522369M total trainable parameters
Traceback (most recent call last):
File "test_avit.py", line 409, in
```
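The log reports two counts, presumably all parameters versus those with `requires_grad=True`. The large gap (about 92.3M versus 6.5M) suggests that most of the backbone is frozen. The following is a minimal sketch of how such counts are usually computed in PyTorch. It is an assumption that the repository does it this way, and the helper names are hypothetical.

```python
from __future__ import annotations

import torch.nn as nn


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (total, trainable) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


def report(model: nn.Module) -> str:
    total, trainable = count_parameters(model)
    # Mimics the "<millions>M ..." style of the log; the repository's code may differ.
    return f"{total / 1e6}M total parameters\n{trainable / 1e6}M total trainable parameters"


# Toy example: freeze a "backbone" and keep only a small "head" trainable.
backbone = nn.Linear(10, 5)  # 10*5 + 5 = 55 parameters
head = nn.Linear(5, 1)       # 5*1 + 1 = 6 parameters
for p in backbone.parameters():
    p.requires_grad = False
toy = nn.Sequential(backbone, head)
print(report(toy))
```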
Hi Dysion, for AViT the `adapt_method` should be `MLP`, as in the command I gave in the repository:
```bash
# ViTSeg_CNNprompt_adapt, SwinSeg_CNNprompt_adapt, DeiTSeg_CNNprompt_adapt
python -u multi_train_adapt.py --exp_name test --config_yml Configs/multi_train_local.yml --model ViTSeg_CNNprompt_adapt --batch_size 16 --adapt_method MLP --num_domains 1 --dataset isic2018 --k_fold 0
```
The input to AViT is an image together with a domain label, so the forward call is `output = model(img, d=d)['seg']`. Since the input data comes from a single domain (dataset), `d` is `'0'`. See `multi_train_adapt.py` for more details on how to run the model.
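The calling convention described above, an image batch plus a domain label with the mask read from the `'seg'` entry, can be sketched with a toy stand-in model. `ToyDomainSegModel` and `predict_masks` are hypothetical. The single output channel and the sigmoid-plus-threshold step are also assumptions; the real AViT output head may differ.

```python
import torch
import torch.nn as nn


class ToyDomainSegModel(nn.Module):
    """Hypothetical stand-in that mimics the call model(img, d=d)['seg']."""

    def __init__(self, num_domains: int = 1):
        super().__init__()
        self.num_domains = num_domains
        self.head = nn.Conv2d(3, 1, kernel_size=1)  # one channel of logits

    def forward(self, img: torch.Tensor, d: str = "0") -> dict:
        # A real domain-adaptive model would select domain-specific parameters
        # from d; this toy only checks that the label is in range.
        if not 0 <= int(d) < self.num_domains:
            raise ValueError(f"domain label {d!r} out of range")
        return {"seg": self.head(img)}


@torch.no_grad()
def predict_masks(model: nn.Module, img: torch.Tensor, d: str = "0", threshold: float = 0.5) -> torch.Tensor:
    """Run single-domain inference and binarize the 'seg' output."""
    model.eval()
    logits = model(img, d=d)["seg"]
    # Assumption: single-channel logits, so sigmoid + threshold yields a binary mask.
    return (torch.sigmoid(logits) > threshold).float()


masks = predict_masks(ToyDomainSegModel(num_domains=1), torch.randn(2, 3, 32, 32), d="0")
print(masks.shape)
```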
Hi! After changing `--adapt_method` from `False` to `MLP`, I still get errors:
```
Traceback (most recent call last):
File "test_avit.py", line 409, in
```
The error comes from `load_state_dict(torch.load(model_dir))`. Does the weights file match the `ViTSeg_CNNprompt_adapt` model?
Also, I initialize the model as follows, with `pretrained=False`:
```python
model = ViTSeg_CNNprompt_adapt(
    pretrained=False,
    pretrained_vit_name=config.vit.name,
    pretrained_folder=config.pretrained_folder,
    img_size=config.data.img_size,
    patch_size=config.vit.patch_size,
    embed_dim=config.vit.embed_dim,
    depth=config.vit.depth,
    num_heads=config.vit.num_heads,
    mlp_ratio=config.vit.mlp_ratio,
    drop_rate=config.vit.dropout_rate,
    attn_drop_rate=config.vit.attention_dropout_rate,
    drop_path_rate=0.2,
    debug=config.debug,
    adapt_method=config.model_adapt.adapt_method,
    num_domains=K,
)
```
Thank you!
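One way to check whether a weights file matches a model, before actually loading it, is to compare the two key sets directly. Unlike `load_state_dict(..., strict=False)`, this never copies tensors, so shape mismatches cannot raise. The helper names below are hypothetical, and the toy modules only reproduce the *kind* of mismatch in the first error (a plain `final_conv` versus extra `prompt_encoder` and `final_conv.N` keys). They are not the repository's classes.

```python
from __future__ import annotations

import torch.nn as nn


def diff_state_dict_keys(model: nn.Module, state_dict: dict) -> tuple[list[str], list[str]]:
    """Compare key names only, without loading any tensors into the model."""
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = sorted(model_keys - ckpt_keys)     # expected by the model, absent from the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)  # present in the checkpoint, unknown to the model
    return missing, unexpected


def count_by_prefix(keys: list[str], depth: int = 1) -> dict[str, int]:
    """Group keys by their first `depth` dotted components, e.g. 'prompt_encoder'."""
    counts: dict[str, int] = {}
    for key in keys:
        prefix = ".".join(key.split(".")[:depth])
        counts[prefix] = counts.get(prefix, 0) + 1
    return counts


# Toy stand-ins (hypothetical) reproducing the shape of the mismatch.
class ToyBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.final_conv = nn.Conv2d(8, 1, kernel_size=1)


class ToyAdapted(nn.Module):
    def __init__(self):
        super().__init__()
        self.prompt_encoder = nn.Conv2d(3, 8, kernel_size=3)
        self.final_conv = nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.BatchNorm2d(8))


missing, unexpected = diff_state_dict_keys(ToyBase(), ToyAdapted().state_dict())
print("missing:", missing)
print("unexpected by prefix:", count_by_prefix(unexpected))
```

With the real files, the checkpoint would be obtained with something like `torch.load("AViT_ViT_B_ISIC_best.pth", map_location="cpu")`. This assumes the file holds a bare state dict, as the `model.load_state_dict(torch.load(model_dir))` call in this thread implies. Grouping the unexpected keys by prefix makes a long error list readable at a glance.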
Hi Dysion, sorry for the confusion. `encoder.norm` is a norm layer that is not used during training. I have fixed the code by adding this unused norm to the model so that the architecture matches the weights. If you run into this issue again, you can simply skip loading the norm.
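Skipping the unused norm while loading can be sketched as follows. The code filters out keys under `encoder.norm.` (the prefix seen in the error log), loads the rest with `strict=False`, and still fails loudly on any *other* mismatch. `load_skipping` and the toy modules are hypothetical; this is one possible approach, not the repository's fix.

```python
from __future__ import annotations

import torch
import torch.nn as nn


def load_skipping(model: nn.Module, state_dict: dict, skip_prefixes=("encoder.norm.",)):
    """Load a checkpoint while ignoring keys under `skip_prefixes`; fail on anything else."""
    def skipped(key: str) -> bool:
        return any(key.startswith(p) for p in skip_prefixes)

    filtered = {k: v for k, v in state_dict.items() if not skipped(k)}
    # strict=False reports name mismatches instead of raising
    # (shape mismatches on shared keys still raise).
    result = model.load_state_dict(filtered, strict=False)
    missing = [k for k in result.missing_keys if not skipped(k)]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    return result


# Toy models (hypothetical): the checkpoint has encoder.norm, the target may not.
class ToyEncoder(nn.Module):
    def __init__(self, with_norm: bool):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        if with_norm:
            self.norm = nn.LayerNorm(4)


class ToyModel(nn.Module):
    def __init__(self, with_norm: bool):
        super().__init__()
        self.encoder = ToyEncoder(with_norm)


ckpt = ToyModel(with_norm=True).state_dict()
target = ToyModel(with_norm=False)
load_skipping(target, ckpt)
print(torch.equal(target.encoder.proj.weight, ckpt["encoder.proj.weight"]))
```

With the real weights, this would look like `load_skipping(model, torch.load("AViT_ViT_B_ISIC_best.pth", map_location="cpu"))`. The helper works whether or not the model defines the extra norm.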
I also uploaded a `play.ipynb` notebook that tests AViT on ISIC as an example. Please let me know if you run into any issues.
Thank you! It worked now.
Hi! Thanks for your wonderful work! Could you share the trained AViT model weights? I'd just like to test them out of interest. Thank you!