NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License

Multiclass Segmentation using SAM #315

Open DhruvAwasthi opened 1 year ago

DhruvAwasthi commented 1 year ago

Can segment anything model be used for finetuning on a multi-class segmentation task? Thanks!

sharonsalabiglossai commented 1 year ago

Did you succeed?

Zahoor-Ahmad commented 1 year ago

Does anyone succeed with multiclass?

rafaelagrc commented 1 year ago

I am also interested in this.

agentdr1 commented 1 year ago

+1

NielsRogge commented 1 year ago

Hi,

Could you clarify what you mean by multiclass? SAM just takes an image and prompt as input, and generates a mask for it. So if you want to generate several masks, you will need to create various prompts.
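The one-mask-per-prompt behaviour can be sketched with the `transformers` SAM classes. This is only a sketch: it uses a small randomly initialized model so it runs offline (the masks are meaningless until you load real weights, e.g. `SamModel.from_pretrained("facebook/sam-vit-base")`), and the box coordinates and blank test image are made up.

```python
import torch
from PIL import Image
from transformers import (SamConfig, SamImageProcessor, SamModel,
                          SamProcessor, SamVisionConfig)

# Small random-init model so the snippet runs offline; swap in
# SamModel.from_pretrained(...) / SamProcessor.from_pretrained(...) for real use
config = SamConfig(vision_config=SamVisionConfig(num_hidden_layers=2))
model = SamModel(config)
processor = SamProcessor(SamImageProcessor())

image = Image.new("RGB", (512, 512))  # stand-in for a real photo
# One box prompt per object you want a mask for, in (x0, y0, x1, y1) pixels
input_boxes = [[[80, 40, 200, 160],
                [210, 50, 330, 170]]]

inputs = processor(image, input_boxes=input_boxes, return_tensors="pt")
original_sizes = inputs.pop("original_sizes")
reshaped_input_sizes = inputs.pop("reshaped_input_sizes")

with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)  # one mask per prompt

# Upscale the low-res predicted masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, original_sizes, reshaped_input_sizes
)
print(masks[0].shape)  # (num_box_prompts, 1, height, width)
```

So two box prompts yield two masks; there is no built-in notion of a class label per mask.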

rafaelagrc commented 1 year ago

> Hi,
>
> Could you clarify what you mean by multiclass? SAM just takes an image and prompt as input, and generates a mask for it. So if you want to generate several masks, you will need to create various prompts.

Hello @NielsRogge, my question is whether it is possible to fine-tune the model to identify different entities given a single prompt.
For example, considering this image: if I select a bounding box containing just the fruit basket as the input prompt, how can I get a multi-class segmentation that identifies each of the fruits? I want to do the same for a more specific case: identifying different structures in medical images.

[image: fruit basket]

Thank you for your attention :)

TAUIL-Abd-Elilah commented 9 months ago

Hi, I found a way to fine-tune the Segment Anything model on a multi-class segmentation task: change `num_multimask_outputs` in `MaskDecoder` to the number of classes you want, load SAM's state dictionary, remove the keys `'mask_decoder.mask_tokens.weight'`, `'mask_decoder.iou_prediction_head.layers.2.weight'`, and `'mask_decoder.iou_prediction_head.layers.2.bias'`, and then load the modified state dictionary into SAM with relaxed strictness.

```python
import torch

# Load the pretrained SAM checkpoint
state_dict = torch.load(f)

# These parameters change shape once num_multimask_outputs is modified,
# so they cannot be loaded as-is
keys_to_remove = ['mask_decoder.mask_tokens.weight',
                  'mask_decoder.iou_prediction_head.layers.2.weight',
                  'mask_decoder.iou_prediction_head.layers.2.bias']

for key in keys_to_remove:
    state_dict.pop(key, None)

# Load everything else; the removed keys keep their random initialization
sam.load_state_dict(state_dict, strict=False)
```

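The key-dropping trick above can be demonstrated offline with a toy module standing in for SAM (`ToyDecoder` is a made-up stand-in; the key name mirrors SAM's `mask_tokens`):

```python
import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    """Minimal stand-in for SAM's mask decoder: just a mask-token embedding."""
    def __init__(self, num_tokens):
        super().__init__()
        self.mask_tokens = nn.Embedding(num_tokens, 8)

old = ToyDecoder(num_tokens=4)  # "pretrained" shape
new = ToyDecoder(num_tokens=6)  # desired number of classes + 1

state_dict = old.state_dict()
# Drop the key whose shape no longer matches the new architecture
keys_to_remove = ["mask_tokens.weight"]
for key in keys_to_remove:
    state_dict.pop(key, None)

# strict=False loads what matches and reports what was skipped;
# the dropped key stays randomly initialized
missing, unexpected = new.load_state_dict(state_dict, strict=False)
print(missing)      # ['mask_tokens.weight']
print(unexpected)   # []
```

`load_state_dict` returns the missing and unexpected keys, which is a convenient sanity check that only the intended parameters were left random.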
jamesheatonrdm commented 8 months ago

@TAUIL-Abd-Elilah how did you load the SAM model in the first place? And how do you change `num_multimask_outputs`? It is not explained.

cristian-cmyk4 commented 8 months ago

> Hi, I found a way to fine-tune the Segment Anything model on a multi-class segmentation task: change `num_multimask_outputs` in `MaskDecoder` to the number of classes you want, load SAM's state dictionary, remove the keys `'mask_decoder.mask_tokens.weight'`, `'mask_decoder.iou_prediction_head.layers.2.weight'`, and `'mask_decoder.iou_prediction_head.layers.2.bias'`, and then load the modified state dictionary into SAM with relaxed strictness.
>
> ```python
> state_dict = torch.load(f)
> keys_to_remove = ['mask_decoder.mask_tokens.weight',
>                   'mask_decoder.iou_prediction_head.layers.2.weight',
>                   'mask_decoder.iou_prediction_head.layers.2.bias']
>
> for key in keys_to_remove:
>     state_dict.pop(key, None)
>
> sam.load_state_dict(state_dict, strict=False)
> ```

Hi, can you explain?

felixvh commented 8 months ago

> Hi, I found a way to fine-tune the Segment Anything model on a multi-class segmentation task: change `num_multimask_outputs` in `MaskDecoder` to the number of classes you want, load SAM's state dictionary, remove the keys `'mask_decoder.mask_tokens.weight'`, `'mask_decoder.iou_prediction_head.layers.2.weight'`, and `'mask_decoder.iou_prediction_head.layers.2.bias'`, and then load the modified state dictionary into SAM with relaxed strictness.
>
> ```python
> state_dict = torch.load(f)
> keys_to_remove = ['mask_decoder.mask_tokens.weight',
>                   'mask_decoder.iou_prediction_head.layers.2.weight',
>                   'mask_decoder.iou_prediction_head.layers.2.bias']
>
> for key in keys_to_remove:
>     state_dict.pop(key, None)
>
> sam.load_state_dict(state_dict, strict=False)
> ```

@TAUIL-Abd-Elilah Can you explain in a bit more detail how you did this?

felixvh commented 8 months ago

I solved it after some time. The approach is slightly different from @TAUIL-Abd-Elilah's.

I first load a SAM model based on my desired architecture.

```python
from transformers import (SamConfig, SamMaskDecoderConfig, SamModel,
                          SamPromptEncoderConfig, SamVisionConfig)

# Initializing the SAM vision, prompt encoder and mask decoder configurations
vision_config = SamVisionConfig()
prompt_encoder_config = SamPromptEncoderConfig()
# num_multimask_outputs = the number of classes you want to predict
mask_decoder_config = SamMaskDecoderConfig(num_multimask_outputs=4)
config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)

# A randomly initialized SamModel with the modified architecture
emptyModel = SamModel(config)
```

This model, however, does not yet contain the pretrained weights you need for fine-tuning. Therefore, I load the previously saved weights into my model. Before loading, I resize `mask_tokens` so that its shape matches the checkpoint; after loading the weights, I resize `mask_tokens` back to the new shape.

```python
import torch
import torch.nn as nn

# Temporarily shrink mask_tokens to the pretrained shape (3 + 1 tokens)
# so that it matches the checkpoint
size = (4, 256)
parameter_tensor = nn.Parameter(torch.rand(size))
emptyModel.mask_decoder.mask_tokens.weight = parameter_tensor

# Load the pretrained weights
state_dict = torch.load("model_weights.pth")
emptyModel.load_state_dict(state_dict, strict=False)

# Resize mask_tokens back to the new shape (4 + 1 tokens, randomly initialized)
size = (5, 256)
parameter_tensor = nn.Parameter(torch.rand(size))
emptyModel.mask_decoder.mask_tokens.weight = parameter_tensor

model = emptyModel
```

Hope this helps the others here!
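A quick offline sanity check for the approach above: build a `SamModel` whose mask decoder has `num_multimask_outputs=4` and confirm the mask-token shape. This is only a sketch; a deliberately small vision config (2 layers) keeps the random-init model light, since no pretrained weights are needed just to verify shapes.

```python
from transformers import (SamConfig, SamMaskDecoderConfig, SamModel,
                          SamPromptEncoderConfig, SamVisionConfig)

config = SamConfig(
    vision_config=SamVisionConfig(num_hidden_layers=2),  # small, for speed
    prompt_encoder_config=SamPromptEncoderConfig(),
    mask_decoder_config=SamMaskDecoderConfig(num_multimask_outputs=4),
)
model = SamModel(config)

# The decoder keeps num_multimask_outputs + 1 mask tokens
# (the extra token is used when multimask_output=False)
print(model.mask_decoder.mask_tokens.weight.shape)  # torch.Size([5, 256])
```

This confirms why the checkpoint's `(4, 256)` mask-token tensor no longer fits once `num_multimask_outputs` is raised to 4.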