MedicineToken / Medical-SAM2

Medical SAM 2: Segment Medical Images As Video Via Segment Anything Model 2
Apache License 2.0

Using pretrained weights on REFUGE dataset does not give good results #9

Closed: jacoblam3112 closed this issue 2 weeks ago

jacoblam3112 commented 1 month ago

Hi, thank you for sharing your code. I downloaded the pretrained weights (MedSAM2_pretrain.pth) from the link you provided, loaded them into the model, and ran the evaluation via train_2d.py on the REFUGE dataset (also downloaded from your link) without any fine-tuning. I was hoping the pretrained model would give decent results so I could confirm that the inference pipeline works correctly, but I get very low numbers, as shown below:

INFO:root:Total score: 1.5592443943023682, IOU: 0.01751479435546341, DICE: 0.029827685931491824 || @ epoch 0.

For reference, I slightly modified train_2d.py to run validation directly instead of training for the first epoch. The rest of the arguments I used are as follows:

Namespace(b=1, data_path='./dataset/REFUGE', dataset='REFUGE', distributed='none', encoder='vit_b', exp_name='REFUGE_MedSAM2', gpu=True, gpu_device=0, image_size=1024, lr=0.0001, memory_bank_size=16, multimask_output=1, net='sam2', out_size=1024, path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Samples'}, pretrain='MedSAM2_pretrain.pth', prompt='bbox', prompt_freq=2, sam_ckpt='./checkpoints/sam2_hiera_tiny.pt', sam_config='sam2_hiera_t', train_vis=False, val_freq=1, video_length=2, vis=True, weights=0)

Could you advise what the problem might be?

Thank you
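For reference, IOU and DICE in the log above are overlap scores between the predicted and ground-truth masks, where 1.0 is a perfect match, so values around 0.02 mean the predictions barely overlap the targets. Below is a minimal sketch of the standard binary-mask definitions using NumPy. The function name binary_iou_dice, the 0.5 threshold, and the eps smoothing are assumptions for illustration; Medical-SAM2's own evaluation code may threshold, average across masks (e.g. optic disc and cup in REFUGE), or aggregate over samples differently.

```python
import numpy as np


def binary_iou_dice(pred_prob, target, threshold=0.5, eps=1e-7):
    """Return (IoU, Dice) for one predicted probability map vs. a binary target.

    Hypothetical helper; the threshold and eps smoothing are illustrative choices.
    """
    pred = np.asarray(pred_prob) > threshold  # binarize the prediction
    gt = np.asarray(target).astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    iou = (intersection + eps) / (union + eps)
    # Dice = 2|A ∩ B| / (|A| + |B|)
    dice = (2 * intersection + eps) / (pred.sum() + gt.sum() + eps)
    return float(iou), float(dice)


# Toy 4x4 example: prediction covers 4 pixels, target covers 6, and 4 overlap.
pred = np.zeros((4, 4))
pred[0:2, 0:2] = 0.9
gt = np.zeros((4, 4))
gt[0:2, 0:3] = 1
print(binary_iou_dice(pred, gt))  # roughly (0.667, 0.8)
```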

BA88 commented 4 weeks ago

Thank you, @jiayuanz3, @YunliQi , @WuJunde for developing and sharing this code. We are very excited to try to segment RGB images without needing to prompt on every slice.

I have the same issue that @jacoblam3112 reports. Similar to what they did, I did this:

Next, I trained my own model using train_2d.py. My training results are:

     weights_dict ::
            epoch     = 61
            model     = sam2
            best_dice = 0.8716206283711675
            best_tol  = 0.10188701003789902

Using this model that I trained, I did this:

Do you have code that runs inference on unseen data, e.g., the REFUGE validation set, using a model saved as in train_2d.py:

            if edice > best_dice:
                best_dice = edice
                best_tol = tol
                is_best = True

                save_checkpoint({
                    'epoch': epoch + 1,
                    'model': args.net,
                    'state_dict': sd,
                    'optimizer': optimizer.state_dict(),
                    'best_dice': best_dice,
                    'best_tol': best_tol,
                    'path_helper': args.path_helper,
                }, is_best, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
            else:
                is_best = False

I am happy to test and share back to you any code you might provide. Thank you.
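The checkpoint written by the save_checkpoint call quoted above is a plain dict: the model weights live under 'state_dict', next to metadata such as 'epoch' and 'best_dice'. Below is a minimal, pure-Python sketch of unpacking such a dict before loading it into a model. The helper name unpack_checkpoint is hypothetical, and stripping a "module." prefix (which DataParallel/DistributedDataParallel wrappers add to parameter names) is a common fix for key mismatches, not a confirmed cause of the bug discussed in this thread.

```python
def unpack_checkpoint(checkpoint, prefix="module."):
    """Split a checkpoint dict into (state_dict, metadata).

    Hypothetical helper. Parameter names that start with `prefix` (as added by
    DataParallel/DistributedDataParallel wrappers) have it stripped; whether
    Medical-SAM2 checkpoints carry such a prefix is an assumption.
    """
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        # Keep small metadata (epoch, best_dice, ...); drop weights and optimizer state.
        metadata = {k: v for k, v in checkpoint.items()
                    if k not in ("state_dict", "optimizer")}
    else:
        # Already a bare state_dict.
        state_dict, metadata = checkpoint, {}
    cleaned = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    return cleaned, metadata


# Toy checkpoint with the same top-level keys as save_checkpoint above;
# the numbers stand in for tensors.
ckpt = {
    "epoch": 3,
    "model": "sam2",
    "state_dict": {"module.image_encoder.w": 1.0, "sam_mask_decoder.b": 2.0},
    "optimizer": {},
    "best_dice": 0.5,
    "best_tol": 0.1,
    "path_helper": {},
}
sd, meta = unpack_checkpoint(ckpt)
print(sorted(sd))     # ['image_encoder.w', 'sam_mask_decoder.b']
print(meta["epoch"])  # 3
```

With PyTorch, the dict would come from torch.load(path, map_location="cpu"), and model.load_state_dict(sd, strict=False) returns missing_keys and unexpected_keys, which quickly shows whether the parameter names line up with the rebuilt model.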

WuJunde commented 4 weeks ago

The pretrained weights are for 3D, not 2D.

The 2D weights will still take a while to release, since there is a bug in saving and loading the weights.

@jiayuanz3 is working on this.

BA88 commented 4 weeks ago

Great, thank you so much for the update! Best wishes. Nice job!

jiayuanz3 commented 4 weeks ago

Sorry for the late reply. I'm still investigating the weird error in saving and loading weights 😢 For your current work, you might need to run the training with validation directly after each epoch. The visualisation results can be set to save at the same time. Thanks for pointing this out and I apologise again for any inconvenience brought to you 😞

owenip commented 4 weeks ago

Hey @jiayuanz3, I am also having the problem with saving and loading SAM2 weights, but my fine-tuned weights are for the original 2D segmentation task, not the Medical-SAM2 approach. Does your fine-tuned model work fine during training but perform weirdly after loading the weights from .pt files?

jiayuanz3 commented 4 weeks ago

Yes... I don't know why it happens. But if you evaluate directly after each training epoch, it works fine. I'll update if I make any progress. Thanks for pointing it out!
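Since evaluating right after training works but evaluating after reloading does not, one way to narrow this down is to check whether the weights survive a save/load round trip unchanged. The sketch below uses NumPy arrays with np.savez/np.load as a stand-in for tensors with torch.save/torch.load; the helper name roundtrip_mismatches and the toy weight names are hypothetical. If the equivalent check on the real model reports no mismatches, the problem is more likely in how the model is rebuilt or configured before loading than in the saved values.

```python
import os
import tempfile

import numpy as np


def roundtrip_mismatches(state_dict, path):
    """Save a {name: array} dict to `path` (.npz), reload it, and report
    (missing_names, changed_names). Hypothetical helper; NumPy stands in for
    PyTorch, where torch.save/torch.load and torch.equal would be used instead.
    """
    np.savez(path, **state_dict)
    with np.load(path) as reloaded:
        missing = sorted(set(state_dict) - set(reloaded.files))
        changed = sorted(
            name for name in state_dict
            if name in reloaded.files
            and not np.array_equal(state_dict[name], reloaded[name])
        )
    return missing, changed


# Toy weights; np.savez appends ".npz" if missing, so the path already ends in it.
weights = {
    "sam_mask_decoder.weight": np.random.rand(4, 4).astype(np.float32),
    "sam_mask_decoder.bias": np.zeros(4, dtype=np.float32),
}
with tempfile.TemporaryDirectory() as tmp:
    print(roundtrip_mismatches(weights, os.path.join(tmp, "ckpt.npz")))  # ([], [])
```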

owenip commented 3 weeks ago

That is exactly my situation, except that I am fine-tuning only the mask decoder. So far I have tried loading only the mask decoder's state_dict, and saving/loading the entire model with torch. No luck so far. I will update here if I make any progress. Thanks.
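For the mask-decoder-only case, one option is to save and load just the decoder's entries of the full state dict. The sketch below selects entries by name prefix and reports missing and unexpected keys, similar to what load_state_dict(strict=False) reports in PyTorch. The prefix "sam_mask_decoder." is an assumption based on the attribute name used in the SAM 2 code base, and the helper names and toy keys are hypothetical; check the real names with model.state_dict().keys().

```python
def select_submodule(state_dict, prefix):
    """Return {name_without_prefix: value} for entries under `prefix` (hypothetical helper)."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def compare_keys(loaded, expected):
    """Return (missing, unexpected) key lists, mirroring strict=False loading."""
    missing = sorted(set(expected) - set(loaded))
    unexpected = sorted(set(loaded) - set(expected))
    return missing, unexpected


# Toy full-model state dict; values stand in for tensors, names are made up.
full = {
    "image_encoder.layer_a.weight": 0,
    "sam_mask_decoder.layer_a.weight": 1,  # assumed decoder prefix
    "sam_mask_decoder.layer_b.weight": 2,
}
decoder_sd = select_submodule(full, "sam_mask_decoder.")
print(sorted(decoder_sd))
# ['layer_a.weight', 'layer_b.weight']
print(compare_keys(decoder_sd, ["layer_a.weight", "layer_b.weight", "layer_c.weight"]))
# (['layer_c.weight'], [])
```

With PyTorch, the selected dict could then be passed to the decoder submodule's own load_state_dict (again assuming the sam_mask_decoder attribute name).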

1275468127 commented 2 weeks ago

Hey @owenip, I am also experiencing an issue with saving and loading the SAM2 weights. Have you resolved it yet?

jiayuanz3 commented 2 weeks ago

Sorry again for the delay in fixing the error; it should work now. For validation, you can use function.validation_sam in train_2d.py. Simply point -sam_ckpt at your pretrained weights instead of SAM2's weights 😯
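As a rough illustration of that last step: -sam_ckpt is a single-dash flag whose value lands in args.sam_ckpt, as in the Namespace printed in the first post, so validating your own weights should only mean changing that one value. The parser below is a hypothetical, cut-down stand-in for the one in Medical-SAM2, not its actual definition, and the checkpoint path is illustrative.

```python
import argparse


def build_parser():
    # Hypothetical subset of the options seen in the Namespace above.
    parser = argparse.ArgumentParser(description="Cut-down stand-in for train_2d.py options")
    parser.add_argument("-sam_ckpt", type=str, default=None, help="weights to load")
    parser.add_argument("-sam_config", type=str, default="sam2_hiera_t")
    parser.add_argument("-dataset", type=str, default="REFUGE")
    return parser


parser = build_parser()
# Original SAM2 weights vs. your own trained checkpoint (illustrative path).
base = parser.parse_args(["-sam_ckpt", "./checkpoints/sam2_hiera_tiny.pt"])
tuned = parser.parse_args(["-sam_ckpt", "path/to/best_dice_checkpoint.pth"])
print(base.sam_ckpt, "->", tuned.sam_ckpt)
```

Everything else (config, dataset, image size) stays as it was for training, which keeps the rebuilt model consistent with the saved weights.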