NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License

Fine tuning SAM with multiple bounding boxes #331

Closed mohdsaqibxa closed 11 months ago

mohdsaqibxa commented 11 months ago

@NielsRogge I tried to fine-tune SAM on a custom dataset using your notebook, but with multiple bounding boxes per image instead of one.

https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb

Right now, I think this notebook fine-tunes SAM with only one bounding box prompt per image. When I tried it with multiple bounding boxes, I got this error:

ValueError                                Traceback (most recent call last)
Cell In[31], line 23
     21 predicted_masks = outputs.pred_masks.squeeze(1)
     22 ground_truth_masks = batch["ground_truth_mask"].float().to(device)
---> 23 loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))
     25 # backward pass (compute gradients of parameters w.r.t. loss)
     26 optimizer.zero_grad()

File /anaconda/envs/train_sam/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File /anaconda/envs/train_sam/lib/python3.8/site-packages/monai/losses/dice.py:733, in DiceCELoss.forward(self, input, target)
    722 """
    723 Args:
    724     input: the shape should be BNH[WD].
   (...)
    730 
    731 """
    732 if len(input.shape) != len(target.shape):
--> 733     raise ValueError(
    734         "the number of dimensions for input and target should be the same, "
    735         f"got shape {input.shape} and {target.shape}."
    736     )
    738 dice_loss = self.dice(input, target)
    739 ce_loss = self.ce(input, target)

ValueError: the number of dimensions for input and target should be the same, got shape torch.Size([2, 3, 1, 256, 256]) and torch.Size([2, 1, 256, 256]).

Dimensions for one single image:

example = train_dataset[0]
for k, v in example.items():
    print(k, v.shape)

-----------------------------------------------------------------------------------------
pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_boxes torch.Size([3, 4])
ground_truth_mask (256, 256)

Dimensions for a batch of 2:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)

-----------------------------------------------------------------------------------------
pixel_values torch.Size([2, 3, 1024, 1024])
original_sizes torch.Size([2, 2])
reshaped_input_sizes torch.Size([2, 2])
input_boxes torch.Size([2, 3, 4])
ground_truth_mask torch.Size([2, 256, 256])

Here, I am using 3 bounding boxes for each image.

Please help me with this. How can I fine-tune with multiple bounding boxes? Also, please let me know if I am missing something here. Thanks in advance.
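For readers following the shapes: the mismatch in the traceback can be reproduced without the model. The sketch below is only an illustration. NumPy stands in for PyTorch, `squeeze_like_torch` is a hypothetical helper that mimics `Tensor.squeeze(dim)` (which leaves a dimension untouched when its size is not 1), and the one-box shape is inferred from the three-box shape in the error message rather than taken from the thread.

```python
import numpy as np

def squeeze_like_torch(x, dim):
    """Mimic torch.Tensor.squeeze(dim): drop `dim` only if its size is 1."""
    return np.squeeze(x, axis=dim) if x.shape[dim] == 1 else x

# Zeros stand in for real model outputs; only the shapes matter here.
pred_one_box = np.zeros((2, 1, 1, 256, 256))      # assumed: (batch, boxes=1, masks, H, W)
pred_three_boxes = np.zeros((2, 3, 1, 256, 256))  # shape reported in the traceback
ground_truth = np.zeros((2, 256, 256))            # batch["ground_truth_mask"]

target = np.expand_dims(ground_truth, 1)  # same effect as .unsqueeze(1)

# One box: the size-1 box dimension is dropped, so both sides are 4-D.
print(squeeze_like_torch(pred_one_box, 1).shape, target.shape)

# Three boxes: dimension 1 has size 3, so nothing is dropped and the
# prediction stays 5-D while the target is 4-D.
print(squeeze_like_torch(pred_three_boxes, 1).shape, target.shape)
```

With three boxes the prediction keeps its box dimension, which appears to be why the loss sees a 5-D input against a 4-D target.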

NielsRogge commented 11 months ago

Hi,

If you want to fine-tune SAM with multiple bounding boxes, you need to create several training examples, each containing a single (image, box, mask) triplet.
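A minimal sketch of what "several training examples" could look like, not the notebook's actual code. The function name `expand_to_triplets`, the dictionary keys, and the `(x_min, y_min, x_max, y_max)` box order are assumptions made for illustration. The sketch takes one mask per box and leaves open whether that is the same full mask repeated or a separate mask per object.

```python
import numpy as np

def expand_to_triplets(image, boxes, masks):
    """Turn one image with N boxes into N single-box training examples.

    `masks` holds one ground-truth mask per box; how those masks are
    obtained is up to the caller.
    """
    if len(boxes) != len(masks):
        raise ValueError("need exactly one mask per box")
    return [
        {"image": image, "box": np.asarray(box, dtype=np.float32), "mask": mask}
        for box, mask in zip(boxes, masks)
    ]

# Hypothetical toy data: one 8x8 image with two boxes.
image = np.zeros((8, 8, 3), dtype=np.uint8)
full_mask = np.zeros((8, 8), dtype=np.uint8)
boxes = [[0, 0, 3, 3], [4, 4, 7, 7]]

# Here every triplet reuses the full mask; per-object masks would work too.
examples = expand_to_triplets(image, boxes, [full_mask] * len(boxes))
print(len(examples), examples[0]["box"].shape)
```

Each resulting example has exactly one box, so a dataset built this way should produce `input_boxes` with a single box per image, matching the setup the notebook was written for.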

NielsRogge commented 11 months ago

Will close this issue as it's resolved, feel free to reopen.

nahidalam commented 10 months ago

@NielsRogge you mentioned we have to create triplets like (image, box, mask). So let's say I have 20 bounding boxes in an image. Will the triplets be

option 1:

(image, box1, mask)
(image, box2, mask)
...

Or option 2:

(image, box1, mask1)
(image, box2, mask2)
...

Option 2 makes sense, but it is not clear to me how we would extract mask1, mask2, etc. corresponding to each box from the given ground truth mask.
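One possible way to get per-box masks for option 2, sketched under assumptions and not confirmed anywhere in this thread: keep only the ground-truth pixels that fall inside each box. The helper name `mask_for_box` is hypothetical, boxes are assumed to be `(x_min, y_min, x_max, y_max)` with inclusive edges, and the coordinates are assumed to be in the mask's own pixel grid.

```python
import numpy as np

def mask_for_box(full_mask, box):
    """Keep only the foreground pixels of `full_mask` that lie inside `box`.

    `box` is (x_min, y_min, x_max, y_max) in the mask's pixel coordinates,
    with the max edges treated as inclusive.
    """
    x_min, y_min, x_max, y_max = (int(round(v)) for v in box)
    x_min, y_min = max(x_min, 0), max(y_min, 0)  # guard against negative indices
    out = np.zeros_like(full_mask)
    out[y_min:y_max + 1, x_min:x_max + 1] = full_mask[y_min:y_max + 1, x_min:x_max + 1]
    return out

# Hypothetical toy mask with two separate objects.
full_mask = np.zeros((8, 8), dtype=np.uint8)
full_mask[1:3, 1:3] = 1  # object A
full_mask[5:7, 4:7] = 1  # object B

mask1 = mask_for_box(full_mask, (1, 1, 2, 2))  # box around object A
mask2 = mask_for_box(full_mask, (4, 5, 6, 6))  # box around object B
print(mask1.sum(), mask2.sum())
```

This only separates objects cleanly when each box contains a single object. If boxes overlap neighbouring objects, the crop will pick up their pixels too, and instance-level annotations would be a more reliable source for mask1, mask2, and so on.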

sijie-Xu commented 8 months ago

@nahidalam Do you know which way to use?

sijie-Xu commented 8 months ago

(image, box1, mask)
(image, box2, mask)

I've tried both of these methods and found that using all masks works better, because the other way suppresses the generation of other masks of the same kind. [image]

But the effect is still mediocre. This is the result after training for 4 rounds: [image] [image]

saiviswanth commented 7 months ago

Hey @sijie-Xu, I am also working on a project where I have to use multiple bounding boxes as prompts to improve segmentation of multiple objects in an image. Can you tell me how you created the triplets, and could you publish the code? It would be really helpful. Thanks a lot in advance, @sijie-Xu.

sijie-Xu commented 7 months ago

You just create them like this: [image]

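"Hey, I am also working on a project where I have to use multiple bounding boxes as prompts to improve segmentation of multiple objects in an image. Can I know how you created the triplets, and can you publish the code? It would be very helpful. Thanks a lot in advance."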

JanvitaReddy11 commented 3 months ago

@sijie-Xu can you please explain your approach in detail?