NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.

Multiclass Segmentation using SAM #315

Open DhruvAwasthi opened 1 year ago

DhruvAwasthi commented 1 year ago

Can the Segment Anything Model (SAM) be fine-tuned on a multi-class segmentation task? Thanks!

sharonsalabiglossai commented 1 year ago

Did you succeed?

Zahoor-Ahmad commented 11 months ago

Did anyone succeed with multi-class?

rafaelagrc commented 11 months ago

I am also interested in this.

agentdr1 commented 11 months ago

+1

NielsRogge commented 11 months ago

Hi,

Could you clarify what you mean by multiclass? SAM just takes an image and a prompt as input, and generates a mask for it. So if you want to generate several masks, you will need to create several prompts.
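(A minimal sketch of how the one-mask-per-prompt outputs could then be combined into a single multi-class label map. It assumes you have already run SAM once per class prompt and thresholded each output to a boolean mask; the masks below are made-up toy data, not real SAM outputs.)

```python
import torch

H, W = 4, 4
# One boolean mask per prompt/class, as you would get after thresholding
masks = torch.zeros(3, H, W, dtype=torch.bool)
masks[0, :2, :] = True     # region segmented with the "class 1" prompt
masks[1, 2:, :2] = True    # region segmented with the "class 2" prompt
masks[2, 2:, 2:] = True    # region segmented with the "class 3" prompt

# 0 = background, 1..N = class ids; later prompts overwrite earlier ones
label_map = torch.zeros(H, W, dtype=torch.long)
for class_id, mask in enumerate(masks, start=1):
    label_map[mask] = class_id
```

If masks can overlap, you would instead keep per-class scores and take an argmax, but for disjoint prompts this simple overwrite is enough.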

rafaelagrc commented 11 months ago

> Hi,
>
> Could you clarify what you mean by multiclass? SAM just takes an image and prompt as input, and generates a mask for it. So if you want to generate several masks, you will need to create various prompts.

Hello @NielsRogge, my question is whether it is possible to fine-tune the model to identify different entities given a single prompt.
For example, considering this image: if I select a bounding box containing just the fruit basket as the input prompt, how can I get a multi-class segmentation that identifies each of the fruits? I want to do the same for a more specific case: identifying different structures in medical images.

[image: fruit basket]

Thank you for your attention :)

TAUIL-Abd-Elilah commented 8 months ago

Hi, I found a way to fine-tune the Segment Anything Model on a multi-class segmentation task: change num_multimask_outputs in MaskDecoder to the number of classes you want, load SAM's state dictionary, remove the keys 'mask_decoder.mask_tokens.weight', 'mask_decoder.iou_prediction_head.layers.2.weight' and 'mask_decoder.iou_prediction_head.layers.2.bias' (their shapes no longer match the new decoder), and then load the modified state dictionary into SAM with relaxed strictness.

import torch

# f is the path to (or an open file object of) the SAM checkpoint
state_dict = torch.load(f)

# these weights depend on num_multimask_outputs, so their shapes no longer match
keys_to_remove = [
    'mask_decoder.mask_tokens.weight',
    'mask_decoder.iou_prediction_head.layers.2.weight',
    'mask_decoder.iou_prediction_head.layers.2.bias',
]

for key in keys_to_remove:
    state_dict.pop(key, None)

# strict=False tolerates the missing keys; they keep their fresh initialization
sam.load_state_dict(state_dict, strict=False)

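To illustrate why those keys have to be removed, here is a toy sketch in plain PyTorch (not the actual SAM classes): a head whose shape changed cannot receive the old weights, so we drop its keys and load the rest with strict=False.

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self, num_outputs):
        super().__init__()
        self.backbone = nn.Linear(8, 8)        # shape unchanged -> loadable
        self.head = nn.Linear(8, num_outputs)  # shape changes with num_outputs

pretrained = TinyDecoder(num_outputs=3)  # stands in for the original SAM
new_model = TinyDecoder(num_outputs=5)   # e.g. 5 classes

state_dict = pretrained.state_dict()
for key in ["head.weight", "head.bias"]:  # analogous to the mask_decoder keys
    state_dict.pop(key, None)

result = new_model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)  # exactly the head keys we removed
```

After this, the backbone weights are transferred while the resized head keeps its fresh initialization, ready to be trained on the new classes.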
jamesheatonrdm commented 7 months ago

@TAUIL-Abd-Elilah how did you load the SAM model in the first place? And how do you change num_multimask_outputs? It is not explained.

cristian-cmyk4 commented 7 months ago

> Hi, I find a way to finetune segment anything model on a multi-class segmentation task by changing num_multimask_outputs […]

Hi, can you explain?

felixvh commented 7 months ago

> Hi, I find a way to finetune segment anything model on a multi-class segmentation task by changing num_multimask_outputs […]

@TAUIL-Abd-Elilah Can you explain in a bit more detail how you did this?

felixvh commented 7 months ago

I solved it after some time. The approach is slightly different from @TAUIL-Abd-Elilah's.

I first load a SAM model based on my desired architecture.

from transformers import (SamVisionConfig, SamPromptEncoderConfig,
                          SamMaskDecoderConfig, SamConfig, SamModel)

# Initializing the SAM vision, prompt encoder and mask decoder configurations
vision_config = SamVisionConfig()
prompt_encoder_config = SamPromptEncoderConfig()
mask_decoder_config = SamMaskDecoderConfig(num_multimask_outputs=4)

# Initializing a SamConfig with "facebook/sam-vit-huge"-style defaults
config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
emptyModel = SamModel(config)

This model, however, does not have the trained weights that you need for fine-tuning. Therefore, I load the previously saved weights into my model. Before loading, I need to resize mask_tokens so that its shape matches the checkpoint; after loading the weights, I resize mask_tokens back to the new size.

import torch
import torch.nn as nn

# temporarily shrink mask_tokens to the checkpoint's shape so loading succeeds
size = (4, 256)
emptyModel.mask_decoder.mask_tokens.weight = nn.Parameter(torch.rand(size))

state_dict = torch.load("model_weights.pth")
emptyModel.load_state_dict(state_dict, strict=False)

# grow mask_tokens back to num_multimask_outputs + 1 tokens (freshly initialized)
size = (5, 256)
emptyModel.mask_decoder.mask_tokens.weight = nn.Parameter(torch.rand(size))

model = emptyModel

Hope this helps the others here!
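As a toy illustration of why the temporary resize above is needed (plain PyTorch, not the actual SAM classes): load_state_dict raises a size-mismatch error even with strict=False, so a parameter must match the checkpoint's shape while loading and can only be swapped for the larger, freshly initialised one afterwards.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, n_tokens):
        super().__init__()
        self.tokens = nn.Parameter(torch.rand(n_tokens, 16))  # like mask_tokens
        self.proj = nn.Linear(16, 16)                         # like the rest of SAM

pretrained = Toy(4)  # the checkpoint was saved with 4 tokens
model = Toy(5)       # the new model wants 5 tokens

# Loading directly fails: size mismatches raise even with strict=False
mismatch_failed = False
try:
    model.load_state_dict(pretrained.state_dict(), strict=False)
except RuntimeError:
    mismatch_failed = True

# 1) temporarily shrink the parameter so the (4, 16) checkpoint tokens fit
model.tokens = nn.Parameter(torch.rand(4, 16))
model.load_state_dict(pretrained.state_dict(), strict=False)

# 2) grow the parameter back to 5 freshly initialised tokens
model.tokens = nn.Parameter(torch.rand(5, 16))
```

Note that, as in the snippet above, the 4 loaded token embeddings are discarded in step 2; only the shape-compatible weights (here, proj) are actually transferred.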