facebookresearch / segment-anything-2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Question about the mask decoder: why discard the first mask in multimask? #283

Open jeezrick opened 1 week ago

jeezrick commented 1 week ago
        # Select the correct mask or masks for output
        if multimask_output:
            masks = masks[:, 1:, :, :] 
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # Take the mask output token. Here we *always* use the token for single mask output.
            # At test time, even if we track after 1-click (and using multimask_output=True),
            # we still take the single mask token here. The rationale is that we always track
            # after multiple clicks during training, so the past tokens seen during training
            # are always the single mask token (and we'll let it be the object-memory token).
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

link to code

I wonder why the first mask in masks (the multimask case) is simply discarded. Is it because the first mask is only used for single-mask output in training, so it doesn't apply to multimask output at inference? I don't think it's in the paper. Maybe I missed it; does anyone have an answer? Thanks.

heyoeyo commented 1 week ago

The design seems to carry over from the v1 model, and the paper for that model describes the reasoning in more detail in the appendix (page 17, section: Making the model ambiguity-aware). They say:

"Ambiguity is much rarer with multiple prompts and the three output masks will usually become similar. To minimize computation of degenerate losses at training and ensure the single unambiguous mask receives a regular gradient signal, we only predict a single mask when more than one prompt is given"

The 'single mask' they refer to is the one that gets used when multimask_output=False (meant for cases where more than 1 prompt is given) and is discarded otherwise.
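To make the slicing concrete, here is a minimal sketch of the selection logic quoted above (the dynamic-stability branch is omitted). The shapes and the select_masks helper are assumptions for illustration, not the actual SAM 2 API: the decoder produces 4 mask slots, where slot 0 is the single-mask output and slots 1-3 are the ambiguity-aware multimask outputs.

```python
import numpy as np

# Hypothetical shapes mirroring the mask decoder output:
# slot 0 = single-mask output, slots 1-3 = the three multimask outputs.
B, H, W = 2, 64, 64
masks = np.random.rand(B, 4, H, W)   # [b, 4, h, w]
iou_pred = np.random.rand(B, 4)      # [b, 4]

def select_masks(masks, iou_pred, multimask_output):
    """Sketch of the selection branch (stability fallback omitted)."""
    if multimask_output:
        # Discard the single-mask slot; keep the three multimask outputs.
        return masks[:, 1:, :, :], iou_pred[:, 1:]
    # Keep only the single-mask slot (used when more than 1 prompt is given).
    return masks[:, 0:1, :, :], iou_pred[:, 0:1]

m3, p3 = select_masks(masks, iou_pred, multimask_output=True)
print(m3.shape, p3.shape)   # (2, 3, 64, 64) (2, 3)
m1, p1 = select_masks(masks, iou_pred, multimask_output=False)
print(m1.shape, p1.shape)   # (2, 1, 64, 64) (2, 1)
```

So neither branch "loses" information: the two slots are trained for different prompt regimes, and inference just picks the one that matches the current regime.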

bhack commented 1 week ago

Also, in this v2 there is a multimask_output_for_tracking option.