Error while finetuning SAM #312

Open · DhruvAwasthi opened this issue 1 year ago

DhruvAwasthi commented 1 year ago

@NielsRogge Thanks a lot for your tutorial on how to fine-tune SAM!
I am trying to fine-tune SAM for multiclass segmentation following your notebook. Everything went fine until I created the dataloader and model. Now, when I try to train the model, I run into this issue:

AssertionError: ground truth has different shape (torch.Size([2, 1, 1500, 1500])) from input (torch.Size([2, 1, 256, 256]))

I noticed this happens because my ground-truth mask has a resolution of 1500 x 1500 while the predicted mask in the output has a resolution of 256 x 256, and this causes the error when calculating the seg_loss.
Do you have any idea how to fix this, please? Thank you!

NielsRogge commented 1 year ago

You can upsample the predicted masks using torch.nn.functional.interpolate (docs).

So predicted_masks = torch.nn.functional.interpolate(predicted_masks, size=(1500,1500))
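
A minimal runnable sketch of that fix, assuming the shapes from the error above (the tensors here are dummy stand-ins for the notebook's variables):

import torch
import torch.nn.functional as F

# Dummy tensors matching the shapes in the error above (stand-ins for the
# notebook's predicted_masks and ground_truth_masks).
predicted_masks = torch.randn(2, 1, 256, 256)
ground_truth_masks = torch.randint(0, 2, (2, 1, 1500, 1500)).float()

# Upsample the predictions to the ground-truth resolution before the loss.
predicted_masks = F.interpolate(
    predicted_masks,
    size=ground_truth_masks.shape[-2:],  # (1500, 1500)
    mode="bilinear",
    align_corners=False,
)
print(predicted_masks.shape)  # torch.Size([2, 1, 1500, 1500])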

DhruvAwasthi commented 1 year ago

Okay. I will try this.

On a side note, is this because of the config parameters? And if yes, is there a way to update them so that the model outputs masks at the same resolution as the ground truth? I am just wondering whether upsampling harms the accuracy. Thank you!

Changesong commented 1 year ago

I tried the interpolation method and it works fine, thank you. I would also like to increase the prediction resolution: I want my 4K images to be segmented as cleanly as possible, and 256 seems like it would be too small.

sharonsalabiglossai commented 1 year ago

Hi @NielsRogge,

I have been working with your notebook on fine-tuning SAM on a custom dataset. Do you know if it is possible to adapt it to multiclass segmentation?

Thank you

karthikdatta98 commented 11 months ago

@DhruvAwasthi Hey, I have one query: did you find a way to get a mask with the same resolution as the ground truth? Are there any config parameters to change, or other fine-tuning methods? Also, since SAM was pretrained on 1024x1024, did using 1500x1500 hurt your predictions? Thanks

DhruvAwasthi commented 11 months ago

@karthikdatta98 While fine-tuning SAM, I resized all my images to the target size I wanted the model to be trained on. Once fine-tuned, I used these two functions for inference:


import torch
import torch.nn.functional as F
from torch.nn.functional import threshold, normalize
from typing import Tuple

def postprocess_masks(masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
                      image_size=1024) -> torch.Tensor:
    """
    Remove padding and upscale masks to the original image size.

    Args:
      masks (torch.Tensor):
        Batched masks from the mask_decoder, in BxCxHxW format.
      input_size (tuple(int, int)):
        The size of the image input to the model, in (H, W) format. Used to remove padding.
      original_size (tuple(int, int)):
        The original size of the image before resizing for input to the model, in (H, W) format.
      image_size (int):
        The square resolution the model pads/resizes inputs to (1024 for SAM).

    Returns:
      (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
        is given by original_size.
    """
    # Upscale the low-resolution masks to the padded model input resolution.
    masks = F.interpolate(
        masks,
        (image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )
    # Crop away the padding, then resize to the original image resolution.
    masks = masks[..., : input_size[0], : input_size[1]]
    masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
    return masks

def get_model_response(model, inputs, input_resized_shape, input_original_shape):
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
    # Upscale the predicted masks from 256x256 to the original image size.
    processed_output = postprocess_masks(outputs["pred_masks"][0], input_resized_shape, input_original_shape)
    # Threshold the logits at 0 and normalize to obtain binary-like masks.
    predicted_masks = normalize(threshold(processed_output, 0.0, 0)).squeeze(1)
    return predicted_masks[0]

Here, in the get_model_response() function, I pass the fine-tuned model, the inputs as defined in the fine-tuning script, the input resized shape (the size of the images I trained the model on), and the input original shape (the original size of the input image).
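
For illustration, here is a hypothetical way to call it; the checkpoint path, image, and box prompt below are placeholders in the style of the HF SAM examples, not the exact fine-tuning code:

from PIL import Image
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("path_to_finetuned_checkpoint")  # placeholder path
model.eval()

raw_image = Image.open("example.png").convert("RGB")  # placeholder image
# Placeholder box prompt in (x1, y1, x2, y2) format: one box for one image.
inputs = processor(raw_image, input_boxes=[[[0, 0, 100, 100]]], return_tensors="pt")

mask = get_model_response(
    model,
    inputs,
    input_resized_shape=(1200, 1200),           # size the model was fine-tuned on
    input_original_shape=raw_image.size[::-1],  # PIL size is (W, H); flip to (H, W)
)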

Hope this helps!

karthikdatta98 commented 11 months ago

@DhruvAwasthi Hey, thanks for the quick reply. Since SAM was pre-trained on 1024x1024, I wanted to know whether using a different image resolution hurt the model's predictions. Did the model do a good job? I heard that fine-tuning at 1024x1024 would be ideal. I am planning to train on roughly 2500x3000 images, so I am asking before I start training. Also, can I know how large your dataset was; did you train on around 1000 images? During inference, does this model only give 256x256, and if upscaled to 2K or 3K would it work well?

DhruvAwasthi commented 11 months ago

@karthikdatta98 It worked well. I fine-tuned with 1200x1200 images and the results were pretty good actually. I initially fine-tuned the model with 150 images just to get an idea of how fine-tuning works, and it worked well. Even if upscaled, this works really well. You can give it a shot.

karthikdatta98 commented 11 months ago

@DhruvAwasthi That's great to know. Are you open to sharing whether your dataset worked well on SAM without fine-tuning? Was there a drastic change after fine-tuning? I am hoping your dataset wasn't a generic dataset or classes that overlap with COCO. Can I ask what kind of data it was?

DhruvAwasthi commented 11 months ago

@karthikdatta98 The dataset I fine-tuned on contained data that SAM had not been trained on. Without fine-tuning the results were not good, but they improved a lot after fine-tuning.

karthikdatta98 commented 11 months ago

@DhruvAwasthi When you fine-tuned with 1200x1200 images, was your ground truth 256x256? Or did you upscale the predictions to 1200x1200, keeping the ground truth at 1200x1200 as well?

DhruvAwasthi commented 11 months ago

@karthikdatta98 I upscaled it.

karthikdatta98 commented 11 months ago

@DhruvAwasthi, can you share your code for inference (the main function), either here or at karthikdatta98@gmail.com? Thank you so much, it would be a great help; I am currently working on a project.

karthikdatta98 commented 11 months ago

@DhruvAwasthi Also, after saving with torch.save, how do I load my model.pth? What command did you use? Thanks.

NielsRogge commented 11 months ago

Hi @karthikdatta98, you can use model.save_pretrained("path_to_directory") and model = SamModel.from_pretrained("path_to_directory"); there is no need for torch.save when using HF models.
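
For example, a minimal sketch of that round trip ("sam-finetuned" is a placeholder directory name):

from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# ... fine-tune the model ...

# Save in the Hugging Face format instead of torch.save:
model.save_pretrained("sam-finetuned")
processor.save_pretrained("sam-finetuned")  # keep the processor alongside

# Reload later; no torch.load needed:
model = SamModel.from_pretrained("sam-finetuned")
processor = SamProcessor.from_pretrained("sam-finetuned")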

Raspberry-beans commented 5 months ago

> @karthikdatta98 It worked well. I fine-tuned with 1200x1200 images and the results were pretty good actually. I initially fine-tuned the model with 150 images just to get an idea of how fine-tuning works, and it worked well. Even if upscaled, this works really well. You can give it a shot.

Hi, I hope you are doing well.

I am also fine-tuning the mask decoder on my custom medical images. There are 165 images of size (768, 1024). I also upscaled the predictions from (256, 256) to (768, 1024), trained for 50 epochs, and the loss value was around 0.11.

Unlike your experience of significant improvement after fine-tuning, I got very poor results from my fine-tuned model even when giving point prompts, although ideally I wanted to see improvement even without giving any prompt.

Did you also use any prompts during inference? I was at least expecting good (though not perfect) results after tuning, but the model seems to perform poorly even compared to the non-fine-tuned version.