mazurowski-lab / finetune-SAM

This is an official repo for fine-tuning SAM to customized medical images.
https://arxiv.org/abs/2404.09957
Apache License 2.0
138 stars 22 forks

finetune based on SAM2 or SAM? #27

Open Tony-Duan2020 opened 4 days ago

Tony-Duan2020 commented 4 days ago

Hi,

Thanks for your wonderful work.

With the recent release of SAM2, I am wondering whether it would be more beneficial to fine-tune SAM or SAM2?

Could you kindly provide your thoughts or any advice regarding this?

Best regards!

Tony-Duan2020 commented 2 days ago

Hi,

I ran "SingleGPU_train_finetune_box.py" and saved a checkpoint as follows:

state_dict = sam.state_dict()
save_path = os.path.join(model_saving_path, f'epoch{epoch}_score_{round(dsc.item(), 4)}.pth')
torch.save(state_dict, save_path)

Then I tried to export this checkpoint to ONNX using the script below, but it failed:

https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py

The error was:

RuntimeError: Error(s) in loading state_dict for Sam:
    Missing key(s) in state_dict: "mask_decoder.output_hypernetworks_mlps.3.layers.0.weight", "mask_decoder.output_hypernetworks_mlps.3.layers.0.bias", "mask_decoder.output_hypernetworks_mlps.3.layers.1.weight", "mask_decoder.output_hypernetworks_mlps.3.layers.1.bias", "mask_decoder.output_hypernetworks_mlps.3.layers.2.weight", "mask_decoder.output_hypernetworks_mlps.3.layers.2.bias". 
    Unexpected key(s) in state_dict: "mask_decoder.transformer.layers.0.MLP_Adapter.D_fc1.weight", "mask_decoder.transformer.layers.0.MLP_Adapter.D_fc1.bias", "mask_decoder.transformer.layers.0.MLP_Adapter.D_fc2.weight", "mask_decoder.transformer.layers.0.MLP_Adapter.D_fc2.bias", "mask_decoder.transformer.layers.0.Adapter.D_fc1.weight", "mask_decoder.transformer.layers.0.Adapter.D_fc1.bias", "mask_decoder.transformer.layers.0.Adapter.D_fc2.weight", "mask_decoder.transformer.layers.0.Adapter.D_fc2.bias", "mask_decoder.transformer.layers.1.MLP_Adapter.D_fc1.weight", "mask_decoder.transformer.layers.1.MLP_Adapter.D_fc1.bias", "mask_decoder.transformer.layers.1.MLP_Adapter.D_fc2.weight", "mask_decoder.transformer.layers.1.MLP_Adapter.D_fc2.bias", "mask_decoder.transformer.layers.1.Adapter.D_fc1.weight", "mask_decoder.transformer.layers.1.Adapter.D_fc1.bias", "mask_decoder.transformer.layers.1.Adapter.D_fc2.weight", "mask_decoder.transformer.layers.1.Adapter.D_fc2.bias". 
    size mismatch for mask_decoder.mask_tokens.weight: copying a param with shape torch.Size([3, 256]) from checkpoint, the shape in current model is torch.Size([4, 256]).
    size mismatch for mask_decoder.iou_prediction_head.layers.2.weight: copying a param with shape torch.Size([3, 256]) from checkpoint, the shape in current model is torch.Size([4, 256]).
    size mismatch for mask_decoder.iou_prediction_head.layers.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([4]).
haoyudong-97 commented 1 day ago

Hi,

For your first question, you can check our analysis of SAM 2. TL;DR: SAM and SAM 2 perform similarly on 2D tasks.

For the second question, the error is caused by a different number of classes: the checkpoint holds 3 mask tokens, while the original SAM expects 4. You can set args.num_cls = 3 if you want to apply it to the original SAM (https://github.com/mazurowski-lab/finetune-SAM/blob/e41f73287329ea68df3bfb153c867bdec4eb23b2/cfg.py#L20C5-L20C131).
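
To see which parameters disagree before exporting, the checkpoint can be compared against the target model by name and shape. The following is a minimal sketch, not code from this repo: the helper `diff_state_dicts` is hypothetical, it works on plain name-to-shape dictionaries so it runs without PyTorch, and the toy entries are copied from the error message above (shapes marked as placeholders are made up for illustration).

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(
    checkpoint: Dict[str, Shape], model: Dict[str, Shape]
) -> Dict[str, List[str]]:
    """Group differences into the same three categories that a strict
    state-dict load reports: missing, unexpected, and size mismatch."""
    missing = sorted(k for k in model if k not in checkpoint)
    unexpected = sorted(k for k in checkpoint if k not in model)
    mismatched = sorted(
        f"{k}: checkpoint {checkpoint[k]} vs model {model[k]}"
        for k in checkpoint
        if k in model and checkpoint[k] != model[k]
    )
    return {
        "missing": missing,
        "unexpected": unexpected,
        "size_mismatch": mismatched,
    }


# Toy excerpt of the fine-tuned checkpoint (names taken from the error above).
checkpoint_shapes = {
    "mask_decoder.mask_tokens.weight": (3, 256),
    "mask_decoder.iou_prediction_head.layers.2.bias": (3,),
    # Placeholder shape: the error message does not report it.
    "mask_decoder.transformer.layers.0.Adapter.D_fc1.weight": (1, 1),
}

# Toy excerpt of what the original SAM expects (names taken from the error above).
model_shapes = {
    "mask_decoder.mask_tokens.weight": (4, 256),
    "mask_decoder.iou_prediction_head.layers.2.bias": (4,),
    # Placeholder shape: the error message does not report it.
    "mask_decoder.output_hypernetworks_mlps.3.layers.0.bias": (1,),
}

report = diff_state_dicts(checkpoint_shapes, model_shapes)
for kind, entries in report.items():
    print(f"{kind}:")
    for entry in entries:
        print(f"  {entry}")
```

With real PyTorch objects, the two dictionaries could be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`, once from the loaded checkpoint and once from the freshly built model's `state_dict()`. Each of the three categories corresponds to one part of the RuntimeError above.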