DhruvAwasthi opened this issue 1 year ago
Did you succeed?
Has anyone succeeded with multiclass?
I am also interested in this.
+1
Hi,
Could you clarify what you mean by multiclass? SAM just takes an image and a prompt as input and generates a mask for it. So if you want to generate several masks, you will need to provide several prompts.
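For example, with the SAM integration in Transformers, passing several box prompts for one image gives you one mask per box. A minimal sketch of that idea (the image filename and the two box coordinates below are just placeholders):

import torch
from PIL import Image
from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

image = Image.open("fruit_basket.jpg").convert("RGB")  # hypothetical local image

# Two hypothetical box prompts (xmin, ymin, xmax, ymax), one per object of interest.
input_boxes = [[[100.0, 100.0, 250.0, 250.0], [300.0, 120.0, 420.0, 260.0]]]

inputs = processor(image, input_boxes=input_boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# One predicted mask per box prompt, resized back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)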
Hello @NielsRogge
My question is whether it is possible to fine-tune the model to identify different entities given a single prompt.
For example, considering this image, if I select a bounding box containing just the fruit basket as the input prompt, how can I get multi-class segmentation that identifies each of the fruits?
I want to do the same but for a more specific case: identification of different structures in medical images.
Thank you for your attention :)
Hi, I found a way to fine-tune the Segment Anything model on a multi-class segmentation task: change num_multimask_outputs in MaskDecoder to the number of classes you want, load SAM's state dictionary, remove the keys 'mask_decoder.mask_tokens.weight', 'mask_decoder.iou_prediction_head.layers.2.weight' and 'mask_decoder.iou_prediction_head.layers.2.bias', and then load the modified state dictionary into SAM with relaxed strictness.
import torch

# Load the original SAM checkpoint (f is the path to the downloaded .pth file).
state_dict = torch.load(f)

# Remove the decoder weights whose shapes depend on the number of mask outputs.
keys_to_remove = ['mask_decoder.mask_tokens.weight',
                  'mask_decoder.iou_prediction_head.layers.2.weight',
                  'mask_decoder.iou_prediction_head.layers.2.bias']
for key in keys_to_remove:
    state_dict.pop(key, None)

# Load the remaining weights into the modified model with relaxed strictness.
sam.load_state_dict(state_dict, strict=False)
@TAUIL-Abd-Elilah how did you load the sam model in the first place? And how do you change num_multimask_outputs? It is not explained.
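To sketch how the steps above could fit together end to end with the original segment-anything repo: the ViT-B backbone and checkpoint filename below are assumptions, and one way to make the num_multimask_outputs change is to edit the MaskDecoder(...) call in segment_anything/build_sam.py, where it defaults to 3.

import torch
from segment_anything import sam_model_registry

# Assumptions: a ViT-B backbone and a locally downloaded checkpoint file.
checkpoint_path = "sam_vit_b_01ec64.pth"

# Build the architecture without loading weights. The num_multimask_outputs
# change itself is made beforehand, in the MaskDecoder(...) call inside
# segment_anything/build_sam.py.
sam = sam_model_registry["vit_b"](checkpoint=None)

# Load the checkpoint, drop the keys whose shapes no longer match,
# and load everything else with relaxed strictness.
state_dict = torch.load(checkpoint_path, map_location="cpu")
for key in ["mask_decoder.mask_tokens.weight",
            "mask_decoder.iou_prediction_head.layers.2.weight",
            "mask_decoder.iou_prediction_head.layers.2.bias"]:
    state_dict.pop(key, None)
sam.load_state_dict(state_dict, strict=False)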
Hi @TAUIL-Abd-Elilah, can you explain?
@TAUIL-Abd-Elilah Can you explain in a bit more detail how you did this?
I solved it after some time. The approach is slightly different from @TAUIL-Abd-Elilah's.
I first initialize a SAM model with my desired architecture, using the Transformers SAM classes.
from transformers import (SamVisionConfig, SamPromptEncoderConfig,
                          SamMaskDecoderConfig, SamConfig, SamModel)

# Initializing the SAM vision, prompt encoder and mask decoder configurations,
# with the mask decoder set to the desired number of mask outputs
vision_config = SamVisionConfig()
prompt_encoder_config = SamPromptEncoderConfig()
mask_decoder_config = SamMaskDecoderConfig(num_multimask_outputs=4)

# Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
emptyModel = SamModel(config)
This model, however, does not have the pretrained weights that you need for fine-tuning. Therefore, I load previously saved weights into my model. Before that, I need to adjust mask_tokens so that its size matches the checkpoint, and after loading the weights I adjust mask_tokens back to the size required by the new configuration.
import torch
import torch.nn as nn

# Temporarily shrink mask_tokens to the checkpoint's size (4 tokens of dim 256)
# so that loading the saved weights does not fail on a shape mismatch.
emptyModel.mask_decoder.mask_tokens.weight = nn.Parameter(torch.rand((4, 256)))

state_dict = torch.load("model_weights.pth")
emptyModel.load_state_dict(state_dict, strict=False)

# Re-initialize mask_tokens with the size required by num_multimask_outputs=4
# (4 + 1 mask tokens); these will be learned during fine-tuning.
emptyModel.mask_decoder.mask_tokens.weight = nn.Parameter(torch.rand((5, 256)))

model = emptyModel
Hope this helps the others here!
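One thing the solution above leaves implicit is where model_weights.pth comes from. A minimal sketch of producing it, assuming you start from the pretrained facebook/sam-vit-base checkpoint in Transformers and reuse emptyModel from the snippet above; the shape-based filtering below is my own addition, mirroring the key-removal trick from the earlier comment:

import torch
from transformers import SamModel

# Assumption: start from the pretrained base checkpoint on the Hugging Face Hub.
pretrained = SamModel.from_pretrained("facebook/sam-vit-base")
state_dict = pretrained.state_dict()

# Drop any entries whose shapes no longer match the resized decoder
# (e.g. the IoU prediction head), so that loading with strict=False
# does not fail on size mismatches.
target_shapes = {k: v.shape for k, v in emptyModel.state_dict().items()}
state_dict = {k: v for k, v in state_dict.items()
              if k not in target_shapes or target_shapes[k] == v.shape}

torch.save(state_dict, "model_weights.pth")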
Can the Segment Anything model be used for fine-tuning on a multi-class segmentation task? Thanks!