jacoblam3112 closed this issue 2 weeks ago.
Thank you, @jiayuanz3, @YunliQi, @WuJunde for developing and sharing this code. We are very excited to try segmenting RGB images without needing to prompt on every slice.
I have the same issue that @jacoblam3112 reports. Like them, I first ran validation_sam only on the REFUGE Test dataset.
Next, I trained my own model using train_2d.py. My training results are:
weights_dict ::
epoch = 61
model = sam2
best_dice = 0.8716206283711675
best_tol = 0.10188701003789902
Using the model I trained, I then ran validation_sam only on the REFUGE Validation dataset, and validation_sam only on the REFUGE Test dataset.
Do you have code that runs inference on never-seen data (e.g., the REFUGE Validation dataset) using a model saved as shown in train_2d.py:
    if edice > best_dice:
        best_dice = edice
        best_tol = tol
        is_best = True
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.net,
            'state_dict': sd,
            'optimizer': optimizer.state_dict(),
            'best_dice': best_dice,
            'best_tol': best_tol,
            'path_helper': args.path_helper,
        }, is_best, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
    else:
        is_best = False
I am happy to test and share back to you any code you might provide. Thank you.
The pretrained weight is for 3D, not 2D.
The 2D weight will still take a while to release, since there is a bug in saving and loading the weights.
@jiayuanz3 is working on this.
Great, thank you so much for the update! Best wishes. Nice job!
Sorry for the late reply. I'm still investigating the weird error in saving and loading weights 😢. For your current work, you may need to run validation directly after each training epoch; the visualisation results can be set to save at the same time. Thanks for pointing this out, and I apologise again for any inconvenience 😞
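The workaround above (scoring the live model right after each epoch instead of reloading a saved file) can be sketched as follows. This is a minimal, framework-agnostic sketch, not the repository's code: `train_one_epoch` and `validate` are hypothetical callbacks standing in for the actual training step and function.validation_sam, and `validate` is assumed to return a (tol, dice) pair.

```python
from typing import Callable, List, Tuple


def train_with_inline_validation(
    train_one_epoch: Callable[[int], None],
    validate: Callable[[int], Tuple[float, float]],
    epochs: int,
    val_freq: int = 1,
) -> Tuple[float, float, List[dict]]:
    """Validate the in-memory model after training epochs (hypothetical helper).

    `train_one_epoch` and `validate` are assumed callbacks; `validate`
    returns (tol, dice) for the weights currently held in memory.
    """
    best_dice, best_tol = 0.0, float("inf")
    history: List[dict] = []
    for epoch in range(epochs):
        train_one_epoch(epoch)  # update the weights held in memory
        if epoch % val_freq != 0:
            continue  # no validation scheduled for this epoch
        tol, dice = validate(epoch)  # score the live model, no reload from disk
        history.append({"epoch": epoch, "tol": tol, "dice": dice})
        if dice > best_dice:
            # Same comparison as `edice > best_dice` in train_2d.py; saving
            # visualisations or a checkpoint would also happen here.
            best_dice, best_tol = dice, tol
    return best_dice, best_tol, history
```

Because every score comes from the weights still in memory, the reported numbers do not depend on the save/load path that was misbehaving.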
Hey @jiayuanz3, I am also having problems saving and loading SAM2 weights, but my fine-tuned weights are for the original 2D segmentation task, not the Medical-SAM2 approach. Does your fine-tuned model work fine during training but perform weirdly after loading the weights from .pt files?
Yes... I don't know why it happens. But if you evaluate directly after each training epoch, it works fine. I'll post an update if I make any progress. Thanks for pointing this out!
That is exactly my situation, except that I am fine-tuning only the mask decoder. So far I have tried loading only the mask decoder's state_dict, and saving/loading the entire model with torch. No luck so far. I will update here if I make any progress. Thanks.
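When only part of the model (here the mask decoder) is fine-tuned, it helps to confirm which parameter names are actually saved and restored. A minimal pure-Python sketch on plain dicts of parameter names; `DECODER_PREFIX = "sam_mask_decoder"` is an assumption about how the decoder is named inside the SAM2 model, and `submodule_state` / `key_report` are hypothetical helpers, so check the prefix against your model's `state_dict().keys()` first.

```python
from typing import Dict, Iterable, List, Mapping, Tuple, TypeVar

V = TypeVar("V")

# Assumed attribute name of the mask decoder inside the SAM2 model;
# verify it against your model's state_dict keys before relying on it.
DECODER_PREFIX = "sam_mask_decoder"


def submodule_state(state: Mapping[str, V], prefix: str = DECODER_PREFIX) -> Dict[str, V]:
    """Keep only entries under `prefix` and strip it so keys match the submodule."""
    head = prefix.rstrip(".") + "."
    return {k[len(head):]: v for k, v in state.items() if k.startswith(head)}


def key_report(expected: Iterable[str], loaded: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, sorted for stable comparison."""
    exp, got = set(expected), set(loaded)
    return sorted(exp - got), sorted(got - exp)
```

In PyTorch, the filtered dict would go to the decoder submodule's load_state_dict; with strict=False that call also returns missing and unexpected keys, and empty lists for both are a quick sign that every decoder tensor was restored.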
Hey @owenip, I am also experiencing an issue with saving and loading the SAM2 weights. Have you resolved it yet?
Sorry again for the delay in fixing this error; it should work now. For validation, you can use function.validation_sam as called in train_2d.py. Simply change -sam_ckpt from SAM2's weight to your pretrained weight 😯
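Pointing -sam_ckpt at a self-trained file means the loader has to find the weights inside whatever dict was saved. A minimal sketch, assuming two layouts: the save_checkpoint dict from train_2d.py (weights under 'state_dict', while 'model' holds args.net, e.g. 'sam2') and official SAM2 checkpoints (weights under 'model', which is the key the public SAM2 loader reads). `extract_weights` is a hypothetical helper, not part of the repository.

```python
from collections.abc import Mapping
from typing import Any, Dict


def extract_weights(ckpt: Mapping) -> Dict[str, Any]:
    """Return a bare parameter dict from a loaded checkpoint (hypothetical helper)."""
    if isinstance(ckpt.get("state_dict"), Mapping):
        # Assumed train_2d.py save_checkpoint layout; 'model' is just a name string there.
        sd = ckpt["state_dict"]
    elif isinstance(ckpt.get("model"), Mapping):
        # Assumed official SAM2 layout: weights stored under 'model'.
        sd = ckpt["model"]
    else:
        sd = ckpt  # treat it as an already-bare state dict
    prefix = "module."  # added to keys by DataParallel / DistributedDataParallel wrappers
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}
```

In practice the input would be the result of torch.load(path, map_location="cpu"), and the returned dict would be passed to the model's load_state_dict.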
Hi, thank you for sharing your code. I downloaded the pretrained weights (MedSAM2_pretrain.pth) from the link you provided, loaded them into the model, and ran the evaluation via train_2d.py on the REFUGE dataset (also downloaded from your link) without any fine-tuning. I was hoping the pretrained model would give decent results so I could confirm that the inference pipeline works. However, the scores are very low:
INFO:root:Total score: 1.5592443943023682, IOU: 0.01751479435546341, DICE: 0.029827685931491824 || @ epoch 0.
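One way to check the metric side of an inference pipeline independently of the weights is to score masks whose answer is known. A minimal pure-Python sketch using the standard definitions Dice = 2|A∩B| / (|A| + |B|) and IoU = |A∩B| / |A∪B| on flattened binary masks; `dice_iou` is a hypothetical helper, and the repository's own evaluation may threshold, average, or combine classes differently.

```python
from typing import Sequence, Tuple


def dice_iou(pred: Sequence[int], target: Sequence[int]) -> Tuple[float, float]:
    """Dice and IoU for two flattened binary masks of equal length."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    size_p = sum(1 for p in pred if p)
    size_t = sum(1 for t in target if t)
    union = size_p + size_t - inter
    if union == 0:
        return 1.0, 1.0  # both masks empty: treat as a perfect match
    return 2 * inter / (size_p + size_t), inter / union
```

Identical masks should score 1.0 on both metrics and disjoint masks 0.0; if a check like this passes but real predictions still score near zero, the cause is more likely the weights or how they are loaded than the metric itself.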
For reference, I slightly modified train_2d.py to run validation directly instead of training the first epoch. The arguments I used are:
Namespace(b=1, data_path='./dataset/REFUGE', dataset='REFUGE', distributed='none', encoder='vit_b', exp_name='REFUGE_MedSAM2', gpu=True, gpu_device=0, image_size=1024, lr=0.0001, memory_bank_size=16, multimask_output=1, net='sam2', out_size=1024, path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_08_13_07_04_47/Samples'}, pretrain='MedSAM2_pretrain.pth', prompt='bbox', prompt_freq=2, sam_ckpt='./checkpoints/sam2_hiera_tiny.pt', sam_config='sam2_hiera_t', train_vis=False, val_freq=1, video_length=2, vis=True, weights=0)
Could you advise what the problem might be?
Thank you.