Qsingle / LearnablePromptSAM

Try to use the SAM-ViT as the backbone to create the learnable prompt for semantic segmentation
Apache License 2.0
82 stars, 14 forks

problem #16

Open liuadan opened 3 months ago

liuadan commented 3 months ago

After dividing the pixel values of the binary images by 255 and setting num_class to 2 for training, I got the following error when loading the model for segmentation:
Missing key(s) in state_dict: "image_encoder.pos_embed", "image_encoder.patch_embed.proj.weight", "image_encoder.patch_embed.proj.bias", "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.0.norm2.weight", "image_encoder.blocks.0.norm2.bias", "image_encoder.blocks.0.mlp.lin1.weight", "image_encoder.blocks.0.mlp.lin1.bias", "image_encoder.blocks.0.mlp.lin2.weight", "image_encoder.blocks.0.mlp.lin2.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_h", "image_e............
Did something go wrong during my training?

Qsingle commented 2 months ago

Was your checkpoint downloaded from the official site? Please make sure the checkpoint is correct.

liuadan commented 2 months ago

Yes, I'm sure.

Qsingle commented 2 months ago

This error occurs when the checkpoint is not correct. Could you provide the command and the full log? Thank you very much.

liuadan commented 2 months ago
python train_learnable_sam.py --image C:\Users\Duan\Desktop\LearnablePromptSAM-main\train\image_cut --mask_path C:\Users\Duan\Desktop\LearnablePromptSAM-main\train\gt_cut --model_name vit_b --checkpoint C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts\sam_vit_b_01ec64.pth --save_path C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts --lr 0.05 --mix_precision --optimizer sgd
Qsingle commented 2 months ago

Sorry, please provide the training log as well. You can check the checkpoint this way:

python -c "import torch; print(torch.load(r'C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts\sam_vit_b_01ec64.pth', map_location='cpu').keys())"
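A slightly fuller version of the same check, as a sketch: the helper names below are hypothetical, and unwrapping a `state_dict` or `model` entry is an assumption about how other training scripts may save weights (the official SAM checkpoint should already be a flat dictionary of tensors). Counting parameter names by their top-level module makes a file that lacks the image encoder stand out immediately.

```python
from collections import Counter


def unwrap_state_dict(obj):
    """Return the name -> tensor mapping inside a loaded checkpoint.

    Assumption: some scripts save {"state_dict": {...}} or {"model": {...}};
    a flat checkpoint is returned unchanged.
    """
    if isinstance(obj, dict):
        for wrapper in ("state_dict", "model"):
            inner = obj.get(wrapper)
            if isinstance(inner, dict):
                return inner
    return obj


def prefix_counts(keys):
    """Count parameter names by their top-level module, e.g. 'image_encoder'."""
    return dict(Counter(key.split(".", 1)[0] for key in keys))


def inspect_checkpoint(path):
    """Load a checkpoint on the CPU and summarize its key prefixes."""
    import torch  # imported lazily so the helpers above run without PyTorch

    state_dict = unwrap_state_dict(torch.load(path, map_location="cpu"))
    return prefix_counts(state_dict.keys())
```

Usage: `print(inspect_checkpoint(r"ckpts\sam_vit_b_01ec64.pth"))`. For the official SAM ViT-B checkpoint one would expect prefixes such as `image_encoder`, `prompt_encoder` and `mask_decoder`; a result without `image_encoder` would match the missing keys in the error above.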

If the checkpoint is correct, I suggest changing the command to use relative paths:

python train_learnable_sam.py --image train/image_cut --mask_path train/gt_cut --model_name vit_b --checkpoint ckpts/sam_vit_b_01ec64.pth --save_path ckpts --lr 0.05 --mix_precision --optimizer sgd
CNwuyueyu commented 2 weeks ago

You may have chosen the wrong model. If you trained the SAM model, use PromptSAM instead of PromptDiNo when loading the model for segmentation.
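To see why a model/checkpoint mismatch produces this error, here is a minimal sketch in plain Python: "Missing key(s)" lists parameter names that the model expects but the checkpoint does not contain. The helpers are hypothetical, and the key names in the test are illustrative rather than taken from PromptSAM or PromptDiNo. In practice the two key lists would come from `model.state_dict().keys()` and the loaded checkpoint.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected), the two categories a strict load reports."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # model expects, file lacks
    unexpected = sorted(checkpoint_keys - model_keys)  # file has, model lacks
    return missing, unexpected


def missing_modules(model_keys, checkpoint_keys):
    """Top-level modules the model expects that are absent from the checkpoint."""
    model_prefixes = {key.split(".", 1)[0] for key in model_keys}
    checkpoint_prefixes = {key.split(".", 1)[0] for key in checkpoint_keys}
    return sorted(model_prefixes - checkpoint_prefixes)
```

If `missing_modules` reports a whole backbone (for example `image_encoder`), the checkpoint was most likely saved from a different model class than the one being loaded. PyTorch gives the same information without raising: `model.load_state_dict(state_dict, strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys` fields.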