facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Is it possible to change value of num_multimask_outputs to greater than 3? #299

Open DavidLanders95 opened 1 year ago

DavidLanders95 commented 1 year ago

Is it possible to change the number of mask outputs from the model to a value greater than 3? I see that the class MaskDecoder has an option:

    from typing import Type

    from torch import nn

    class MaskDecoder(nn.Module):

        def __init__(
            self,
            *,
            transformer_dim: int,
            transformer: nn.Module,
            num_multimask_outputs: int = 3,
            activation: Type[nn.Module] = nn.GELU,
            iou_head_depth: int = 3,
            iou_head_hidden_dim: int = 256,
        ) -> None:
            ...

However, when I try to change this value to, say, 5 or 10, I receive an error:

    RuntimeError: Error(s) in loading state_dict for Sam:
        Missing key(s) in state_dict: "mask_decoder.output_hypernetworks_mlps.4.layers.0.weight", "mask_decoder.output_hypernetworks_mlps.4.layers.0.bias", "mask_decoder.output_hypernetworks_mlps.4.layers.1.weight", "mask_decoder.output_hypernetworks_mlps.4.layers.1.bias", "mask_decoder.output_hypernetworks_mlps.4.layers.2.weight", "mask_decoder.output_hypernetworks_mlps.4.layers.2.bias", "mask_decoder.output_hypernetworks_mlps.5.layers.0.weight", "mask_decoder.output_hypernetworks_mlps.5.layers.0.bias", "mask_decoder.output_hypernetworks_mlps.5.layers.1.weight", "mask_decoder.output_hypernetworks_mlps.5.layers.1.bias", "mask_decoder.output_hypernetworks_mlps.5.layers.2.weight", "mask_decoder.output_hypernetworks_mlps.5.layers.2.bias".
        size mismatch for mask_decoder.mask_tokens.weight: copying a param with shape torch.Size([4, 256]) from checkpoint, the shape in current model is torch.Size([6, 256]).
        size mismatch for mask_decoder.iou_prediction_head.layers.2.weight: copying a param with shape torch.Size([4, 256]) from checkpoint, the shape in current model is torch.Size([6, 256]).
        size mismatch for mask_decoder.iou_prediction_head.layers.2.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([6]).

Does anyone know if this is possible to solve?
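One way to at least get past the loading error (a sketch, not an official path in this repo) is to build the model with the larger `num_multimask_outputs` and copy over only the checkpoint weights whose names and shapes still match, leaving the new mask tokens and hypernetwork MLPs at their random initialization. Note this does not avoid training: the newly added parameters would still contain random values. `load_compatible` below is a hypothetical helper, not part of segment-anything:

```python
import torch


def load_compatible(model, ckpt_state_dict):
    """Hypothetical helper: copy into `model` only the checkpoint params
    whose name and shape match the (enlarged) model; all other params keep
    their random initialization. Returns the skipped checkpoint keys."""
    model_sd = model.state_dict()
    compatible = {
        k: v
        for k, v in ckpt_state_dict.items()
        if k in model_sd and v.shape == model_sd[k].shape
    }
    model_sd.update(compatible)
    model.load_state_dict(model_sd)
    return sorted(set(ckpt_state_dict) - set(compatible))


# Usage sketch (builder/path assumed from the repo's README):
# sam = sam_model_registry["vit_h"]()  # built with the larger num_multimask_outputs
# skipped = load_compatible(sam, torch.load("sam_vit_h_4b8939.pth", map_location="cpu"))
```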

starsky0426 commented 1 year ago

Yes you can change it, but you need to train the network again. 😅

EveningLin commented 1 year ago

Why is self.num_mask_tokens 4?

DavidLanders95 commented 1 year ago

> Yes you can change it, but you need to train the network again. 😅

Thanks, I thought that might be the case.

> Why is self.num_mask_tokens 4?

I'm not sure exactly why, but it's always one more than the number of multimask outputs; it's set that way in the decoder code.

halleewong commented 1 year ago

> Why is self.num_mask_tokens 4?

When multimask_output is False, the predicted mask associated with the first mask token is returned. Otherwise, the masks associated with the other 3 mask tokens are returned.
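That selection can be sketched in plain Python (a stand-in for the tensor slicing the decoder does; `select_masks` is illustrative, not the repo's API):

```python
def select_masks(per_token_masks, multimask_output):
    """per_token_masks: one predicted mask per mask token (length 4).
    Token 0 is the single-mask output; tokens 1-3 are the multimask outputs."""
    if multimask_output:
        return per_token_masks[1:]   # the 3 multimask predictions
    return per_token_masks[0:1]      # just the single-mask prediction
```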

Soulergonote commented 1 year ago

Is there a way to generate more than 3 masks for one input without retraining the network?