facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Using a mask prompt for boundary refinement #169

[Open] Gpoxolcku opened this issue 1 year ago

Gpoxolcku commented 1 year ago

Hi, I have a roughly labeled dataset and I'm trying to feed its labels into SAM as a prompt. I want SAM to refine the segmentation labels and improve my dataset's quality. In my case I don't use any additional prompt artifacts like points or boxes (though it works pretty well with such prompts). According to the paper, it seems to me that a pure mask prompt should be supported as well. But the results I get are rather unreliable: the output mask mostly repeats the input one, sometimes even making it slightly worse. Is there a code snippet for building prompts from external masks? Thanks in advance!

maoyj1998 commented 1 year ago

I encounter the same problem. I found that the input mask size must be 256x256, but when I resize my mask to this size, the output segmentation is a mess and makes no sense. Does anyone have a clue?

rmokady commented 1 year ago

Encountered the same issue

kampelmuehler commented 1 year ago

I can confirm that, also in their Colab, if you use just the mask_input without any sparse guidance, the results do not change.

If you want to do the same with a mask from another source, you must at least first zero-pad it to square dimensions and then resize it to 256x256. It might also need some sort of normalization to work properly.

Output of the first stage, obtained using a single guidance point: [image]

Output after feeding the logits of the first stage as mask_input, without any additional queries: [image]

qraleq commented 1 year ago

I'm observing the same behavior. @Gpoxolcku Did you manage to figure it out? I would like to use SAM for exactly the same use case you're mentioning.

Gpoxolcku commented 1 year ago

Not yet, unfortunately. The only workaround I came up with is sampling points inside an instance's mask (with different sampling strategies, e.g. weighted by the distance transform) as the "positive" class and sampling more points outside the mask as the "negative" class. I then combine those sparse inputs with the coarse binary mask into a prompt and feed it into SAM. But that's still not good enough to refine the dataset.
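A minimal numpy sketch of that sampling idea, for illustration only: `distance_to_background` and `sample_prompt_points` are hypothetical helpers, the distance is a city-block approximation of the distance transform (computed by repeated erosion to avoid a SciPy/OpenCV dependency), and negatives are drawn uniformly from the whole background. The output follows the format SamPredictor.predict expects: `point_coords` as an Nx2 array of (x, y) pixels and `point_labels` with 1 for foreground and 0 for background.

```python
import numpy as np


def distance_to_background(mask: np.ndarray) -> np.ndarray:
    """City-block distance from each foreground pixel to the nearest background
    pixel (pixels outside the image count as background), via repeated erosion."""
    current = np.asarray(mask).astype(bool)
    dist = np.zeros(current.shape, dtype=np.int32)
    while current.any():
        dist += current  # every pixel that survived this round gets +1
        p = np.pad(current, 1, mode="constant", constant_values=False)
        # 4-neighbourhood erosion: keep pixels whose up/down/left/right are all set.
        current = p[1:-1, 1:-1] & p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]
    return dist


def sample_prompt_points(mask, n_pos=5, n_neg=5, rng=None):
    """Hypothetical helper: positive points inside the mask (weighted towards the
    interior), negative points uniformly outside it."""
    mask = np.asarray(mask).astype(bool)
    if mask.all() or not mask.any():
        raise ValueError("mask must contain both foreground and background")
    rng = np.random.default_rng(rng)

    # Positives: points deep inside the coarse label are more likely to be correct.
    fg_y, fg_x = np.nonzero(mask)
    weights = distance_to_background(mask)[fg_y, fg_x].astype(np.float64)
    pos = rng.choice(len(fg_y), size=min(n_pos, len(fg_y)), replace=False,
                     p=weights / weights.sum())

    # Negatives: uniform over the background.
    bg_y, bg_x = np.nonzero(~mask)
    neg = rng.choice(len(bg_y), size=min(n_neg, len(bg_y)), replace=False)

    coords = np.concatenate([
        np.stack([fg_x[pos], fg_y[pos]], axis=1),
        np.stack([bg_x[neg], bg_y[neg]], axis=1),
    ]).astype(np.float32)
    labels = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))]).astype(np.int32)
    return coords, labels
```

Weighting by interior distance is one choice among the "different sampling strategies" mentioned; restricting negatives to a band around the object is another option worth trying.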

kampelmuehler commented 1 year ago

@Gpoxolcku Interesting that you had success with this - I found it to not really work well with many query points. How many did you use?

Davidyao99 commented 1 year ago

@kampelmuehler when you say pad the mask, is this so that the mask fits over the transformed input image to the model? (which is also padded and squared.)

tldrafael commented 1 year ago

Using only the mask logits did not work for me, it rendered nonsense results.

On the other hand, querying positive and negative points from the binary mask yielded a better result, and it improves a lot if you do it iteratively: feed the predictor random samples of positive and negative points from the binary mask together with the best logits from the previous round, for several rounds (the paper mentions 11 iterations; see Appendix A, Training Algorithm).

The number of query points does not seem to matter much; it mainly seems to affect how faithfully the output follows the original mask. In my case, although the iterative process improved SAM's output, it didn't refine fine details.
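A minimal sketch of that iterative loop, assuming `predictor` is a SamPredictor on which set_image() has already been called. `iterative_refine` is a hypothetical helper; uniform point sampling and picking the best mask by score are my choices, while the "first round multimask, later rounds single mask with the previous logits" pattern mirrors SAM's predictor example notebook.

```python
import numpy as np


def iterative_refine(predictor, mask, n_iters=11, n_pos=5, n_neg=5, seed=0):
    """Re-prompt SAM n_iters times with fresh points sampled from a coarse mask,
    feeding back the best low-resolution logits from the previous round."""
    rng = np.random.default_rng(seed)
    mask = np.asarray(mask).astype(bool)
    if mask.all() or not mask.any():
        raise ValueError("mask must contain both foreground and background")
    fg = np.argwhere(mask)[:, ::-1]   # (row, col) -> (x, y)
    bg = np.argwhere(~mask)[:, ::-1]
    logits, best = None, None
    for _ in range(n_iters):
        pos = fg[rng.choice(len(fg), size=min(n_pos, len(fg)), replace=False)]
        neg = bg[rng.choice(len(bg), size=min(n_neg, len(bg)), replace=False)]
        coords = np.concatenate([pos, neg]).astype(np.float32)
        labels = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))]).astype(np.int32)
        # First round: several candidate masks; later rounds: one mask guided by logits.
        masks, scores, low_res = predictor.predict(
            point_coords=coords,
            point_labels=labels,
            mask_input=logits,
            multimask_output=logits is None,
        )
        i = int(np.argmax(scores))
        best = masks[i]
        logits = low_res[i : i + 1]  # keep shape (1, 256, 256) for mask_input
    return best, logits
```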

antoniocandito commented 1 year ago

I'm also interested in boundary refinement using a mask prompt.

First, I tried using a bbox, and the object was delineated with good accuracy. Second, I converted the bbox to a binary mask, but the model keeps producing no contours for this prompt. I did resize the binary mask to 1x256x256.

Any help with this please?

kampelmuehler commented 1 year ago

@Davidyao99 yes, precisely

GoGoPen commented 1 year ago

The mask prompt and bbox prompt need to be provided together to generate a proper mask. The mask should be rescaled along its longest side, the same way the image is preprocessed, and then padded to 256x256. Also note that the mask should be a binary mask.
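If the only input is a coarse mask, the box prompt can be derived from the mask itself so both prompts are passed together. A small sketch; `mask_to_box` is a hypothetical helper, and SamPredictor.predict takes `box` as a length-4 array in XYXY pixel format.

```python
import numpy as np


def mask_to_box(mask: np.ndarray) -> np.ndarray:
    """Tight XYXY box (x_min, y_min, x_max, y_max) around a binary mask's
    foreground, in pixel coordinates of the original image."""
    ys, xs = np.nonzero(mask)
    if len(xs) == 0:
        raise ValueError("mask is empty")
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32)
```

It would then be used as e.g. `predictor.predict(box=mask_to_box(mask), mask_input=sam_mask, multimask_output=False)`, where `sam_mask` is the 1x256x256 array prepared as described in this thread.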

cip8 commented 1 year ago

@antoniocandito , did you manage to make the mask input work?

@GoGoPen that's really useful information, thanks for sharing! I tried binary masks, then I tried converting their values to logits (similar to what SAM returns), but with no success. Could you help us out with an approximate example of how to convert a binary mask to the format SAM accepts? 💙🖖

markushunter commented 1 year ago

What is the proper way to pad the mask? Do you add the pad to the lower right, or do you center the mask in the target dimensions and add padding to the top, bottom, left, and right?

markushunter commented 1 year ago

Not sure if this is the proper way to get the mask output, but this is what I discovered...

SamPredictor.predict() states that the mask_input should be something like this:

            mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.

Looking at a histogram of values the model produces in the low_res_masks_np output of predict(), the values are not a boolean mask. The values are floats.

[image: histogram of the values in low_res_masks_np]

In Sam.py, mask_threshold is hardcoded to 0.0. Thresholding the low_res_masks_np output at 0 gave a reasonable mask.

Looking at a thresholded version of low_res_masks_np scaled by 128, the model pads the mask on the bottom and right only.

By making a custom mask_input for SamPredictor.predict() where negative locations are -8 and positive locations are 1 with -8 padding on the bottom and right of the mask as necessary, subsequent reruns of segment anything produced a mask. However, the mask still wasn't perfect.

cip8 commented 1 year ago

Very good observations @markushunter! I'll try to add padding to the bottom-right and see if the results change.

I don't fully understand the part about assigning values of -8 / 1. Do you mean that binary mask values (0/1) should be replaced with -8 and 1 because of the 0.0 mask_threshold in Sam.py?

Thanks for the info! 💙🖖

markushunter commented 1 year ago

@cip8 Yes, instead of using 0 or 1 for the values in the mask, you need to represent the negative space with a number far less than zero. Since SAM thresholds the mask at the floating point value 0.0, having the negative space as 0.0 isn't good enough.

The histogram seemed to imply that negative space in the output mask has values around -8 to -10, so I just ran with -8.
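A minimal numpy sketch of the recipe described above, for illustration only: `binary_to_pseudo_logits` is a hypothetical helper, and -8 / +1 are simply the values reported here, not documented SAM constants. The new size is rounded the way SAM's ResizeLongestSide.get_preprocess_shape does (`int(x + 0.5)`), the resize is nearest-neighbour by index lookup, and padding goes on the bottom and right with the background value.

```python
import numpy as np


def binary_to_pseudo_logits(mask, size=256, neg=-8.0, pos=1.0):
    """Turn an HxW binary mask into a (1, size, size) float32 mask_input:
    resize so the longest side equals `size`, map background -> neg and
    foreground -> pos, then pad the bottom/right with neg."""
    mask = np.asarray(mask).astype(bool)
    h, w = mask.shape
    scale = size / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    # Nearest-neighbour resize: map each output row/col back to a source index.
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    resized = mask[rows[:, None], cols[None, :]]
    out = np.full((size, size), neg, dtype=np.float32)
    out[:new_h, :new_w] = np.where(resized, pos, neg)
    return out[None]
```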

cip8 commented 1 year ago

The docs say that logits from a previous run can be used for this mask_input.

These logits are indeed floats and look like this:

(1, 256, 256) / float32 
[[[-11.90418   -12.534466  -13.846361  ... -19.109943  -19.418356  -18.89853  ]
  [-12.359286  -15.481771  -14.45459   ... -20.847857  -19.149311  -20.422709 ]
  [-11.877727  -13.034173  -13.75271   ... -18.832436  -20.500711  -19.762798 ]
  ...
  [ -2.2596788   7.5174465   6.552994  ... -12.575503  -11.790027  -11.399892 ]
  [ -2.4307566   8.219517    6.204507  ... -10.520343  -11.538195  -10.084738 ]
  [ -2.2835727   5.5459557   5.4847026 ... -11.283685  -11.467551  -9.843957 ]]]

From what I understand, they represent mask probabilities. Does anyone know if that's accurate?
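Assuming these are ordinary sigmoid logits (which the 0.0 mask_threshold suggests), a sigmoid maps them to per-pixel probabilities, and thresholding logits at 0 is the same as thresholding probabilities at 0.5. A small sketch; `logits_to_probs` is a hypothetical helper written in a numerically stable form so values like -1000 do not overflow.

```python
import numpy as np


def logits_to_probs(logits: np.ndarray) -> np.ndarray:
    """Element-wise sigmoid, computed as exp(-log(1 + exp(-x))) for stability."""
    return np.exp(-np.logaddexp(0.0, -logits))
```

With the values printed above, a background logit of about -11.9 corresponds to a probability below 1e-5, while 7.5 corresponds to roughly 0.9995.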

cip8 commented 1 year ago

Grayscale mask to SAM mask_input:

Based on the info discussed so far, this is how I implemented a conversion between grayscale and SAM's mask_input in my code:

import cv2
import numpy as np


class Segmentix:
    # [...] (rest of the class omitted)

    def resize_mask(
        self, ref_mask: np.ndarray, longest_side: int = 256
    ) -> tuple[np.ndarray, int, int]:
        """
        Resize an image to have its longest side equal to the specified value.

        Args:
            ref_mask (np.ndarray): The image to be resized.
            longest_side (int, optional): The length of the longest side after resizing. Default is 256.

        Returns:
            tuple[np.ndarray, int, int]: The resized image and its new height and width.
        """
        height, width = ref_mask.shape[:2]
        if height > width:
            new_height = longest_side
            new_width = int(width * (new_height / height))
        else:
            new_width = longest_side
            new_height = int(height * (new_width / width))

        # cv2.resize takes (width, height); nearest neighbour keeps mask values discrete.
        return (
            cv2.resize(
                ref_mask, (new_width, new_height), interpolation=cv2.INTER_NEAREST
            ),
            new_height,
            new_width,
        )

    def pad_mask(
        self,
        ref_mask: np.ndarray,
        new_height: int,
        new_width: int,
        pad_all_sides: bool = False,
    ) -> np.ndarray:
        """
        Add padding to an image to make it 256x256.

        Args:
            ref_mask (np.ndarray): The image to be padded.
            new_height (int): The height of the image after resizing.
            new_width (int): The width of the image after resizing.
            pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.

        Returns:
            np.ndarray: The padded image.
        """
        pad_height = 256 - new_height
        pad_width = 256 - new_width
        if pad_all_sides:
            padding = (
                (pad_height // 2, pad_height - pad_height // 2),
                (pad_width // 2, pad_width - pad_width // 2),
            )
        else:
            padding = ((0, pad_height), (0, pad_width))

        # The padding value defaults to 0 when the np.pad mode is "constant".
        return np.pad(ref_mask, padding, mode="constant")

    def reference_to_sam_mask(
        self, ref_mask: np.ndarray, threshold: int = 127, pad_all_sides: bool = False
    ) -> np.ndarray:
        """
        Convert a grayscale mask to a binary mask, resize it to have its longest side equal to 256, and add padding to make it square.

        Args:
            ref_mask (np.ndarray): The grayscale mask to be processed.
            threshold (int, optional): The threshold value for the binarization. Default is 127.
            pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.

        Returns:
            np.ndarray: A 1x256x256 float32 array (1 = foreground, -1 = background, 0 = padding) for use as mask_input.
        """

        # Convert a grayscale mask to a binary mask.
        # Values above the threshold become 1, all others become -1.
        # Cast to float32, a dtype cv2.resize supports (the default integer
        # dtype produced by numpy arithmetic, typically int64, may be rejected).
        ref_mask = np.where(ref_mask > threshold, 1.0, -1.0).astype(np.float32)

        # Resize to have the longest side 256.
        resized_mask, new_height, new_width = self.resize_mask(ref_mask)

        # Add padding to make it square.
        square_mask = self.pad_mask(resized_mask, new_height, new_width, pad_all_sides)

        # Expand SAM mask's dimensions to 1xHxW (1x256x256).
        return np.expand_dims(square_mask, axis=0)

Usage example:

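# Convert reference mask to SAM format & run predictor (inside a Segmentix method).
# Assumes `predictor` is a SamPredictor whose set_image() has already been called,
# and `ref_bbox` is an XYXY box around the object.
sam_mask: np.ndarray = self.reference_to_sam_mask(reference_mask)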
masks, scores, logits = predictor.predict(
    multimask_output=False,
    box=np.array(ref_bbox),
    mask_input=sam_mask,
)


AkshitSharma1 commented 1 year ago

Hi all,

I am trying to refine a cell-segmentation foreground/background mask predicted by another model using SAM. I have tried the iterative approach and the grayscale-to-mask_input approach (as described by @cip8), but neither helped. Could someone please guide me? All my images are grayscale, of size (256, 256).

qraleq commented 1 year ago

Hi all, has anyone managed to solve this problem efficiently? I see marginal improvements using the approach @cip8 proposed.

danigarciaoca commented 1 year ago

Hi everyone,

If you check this demo notebook, it explains that the input mask is not a mask in that sense: it is supposed to be the low-resolution mask output from a previous iteration (prediction):

If available, a mask from a previous iteration can also be supplied to the model to aid in prediction

So for now, it seems it is not possible to prompt with an accurate mask (or at least not with good results).

Hope it helps!

cip8 commented 1 year ago

Is there a law that says we are not allowed to "fake" these logits? 😃

So far in this conversation people have reached different conclusions on how to replicate the behavior of these masks, where the threshold point is, etc. I don't think an answer that doesn't take the rest of the thread into consideration is helpful. Anyone can say "this can't be done", but that's not a real hacker mentality and rarely achieves anything.

cip8 commented 1 year ago

I think the problem lies in the "weight" associated with this extra mask parameter. My intuition is that the model doesn't put much importance on it, focusing instead on points and bounding boxes, because even when you supply it with "clean" logits from another run, the results change only marginally.

Maybe the next version will put a bigger importance on this param, and maybe accept sizes greater than 256x256 - this will make the model easier to include in existing image processing pipelines.

As a trick to bypass this I extract a grid of points from the mask and pass it to SAM instead - the results are much better than the minor changes provided by using the mask_input.

I wish someone from Meta could clarify this for us 🙏

💙🖖
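A minimal sketch of the "grid of points from the mask" trick mentioned above, for illustration only: `grid_points_from_mask`, the `step` spacing and the optional negative points are assumptions, since the grid density and whether negatives were used are not stated.

```python
import numpy as np


def grid_points_from_mask(mask, step=32, with_negatives=False):
    """Hypothetical helper: sample a regular grid (cell centres every `step`
    pixels) and label each point by the mask value underneath it."""
    mask = np.asarray(mask).astype(bool)
    h, w = mask.shape
    ys, xs = np.meshgrid(
        np.arange(step // 2, h, step), np.arange(step // 2, w, step), indexing="ij"
    )
    ys, xs = ys.ravel(), xs.ravel()
    labels = mask[ys, xs].astype(np.int32)  # 1 inside the mask, 0 outside
    keep = np.ones_like(labels, dtype=bool) if with_negatives else labels == 1
    coords = np.stack([xs[keep], ys[keep]], axis=1).astype(np.float32)
    return coords, labels[keep]
```

The result can be passed as `point_coords` / `point_labels` to SamPredictor.predict; a smaller `step` gives more points and, per the earlier comments, probably more fidelity to the original mask.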

danigarciaoca commented 1 year ago

@cip8 I apologize if my response was not to your liking or not what you were looking for.

If you had taken the time to read the paper in depth and replicate the SAM architecture (not just read the docs...), you would better understand the purpose of this mask_input.

Of course it is possible to replicate it; it is just coding and imitating. My point was about replicating it with the desired results. Once again, if you read this thread with all its comments, you can see that nobody has gotten the "refinement" results that everyone (including me) was expecting. This is because mask_input is expected to be used in conjunction with a point (or box) prompt, not by itself.

PS: if it was so straightforward, Meta would have released it for mask prompting...

cip8 commented 1 year ago

It's not about that @dankresio - every contribution is of course helpful and I appreciate your reply, truly! It just seemed to me that your answer didn't take into consideration what was discussed before, and I'm also quite easily triggered by "can't be done" types of answers 😅 I also apologize for my harsh reply 💙🖖

shelper commented 1 year ago

I wonder whether, instead of "extracting a grid of points from the mask and passing it to SAM", shrinking the mask prompt by a certain number of pixels (so that the sampled points don't fall outside the ground-truth mask) and sampling a few points on the edge of the shrunken mask would give better results. To me, it may constrain SAM in a way similar to a mask prompt.

I might test this and report back if I do. Others can post updates if anyone here has time to try it out.
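A minimal numpy sketch of that idea, untested against SAM itself: `erode` and `edge_points_of_shrunken_mask` are hypothetical helpers, the shrink uses a 4-neighbourhood erosion repeated `shrink` times, and the "edge" is the one-pixel inner boundary of the shrunken mask. All returned points are labeled positive.

```python
import numpy as np


def erode(mask, iterations):
    """Binary erosion with a 3x3 cross (4-neighbourhood), repeated `iterations` times."""
    out = np.asarray(mask).astype(bool)
    for _ in range(iterations):
        p = np.pad(out, 1, mode="constant", constant_values=False)
        out = p[1:-1, 1:-1] & p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]
    return out


def edge_points_of_shrunken_mask(mask, shrink=5, n_points=8, rng=None):
    """Shrink the mask by `shrink` pixels, then sample positive points on the
    boundary of the shrunken region."""
    inner = erode(mask, shrink)
    edge = inner & ~erode(inner, 1)  # one-pixel inner boundary of the shrunken mask
    ys, xs = np.nonzero(edge)
    if len(xs) == 0:
        raise ValueError("mask vanished after shrinking; use a smaller `shrink`")
    rng = np.random.default_rng(rng)
    idx = rng.choice(len(xs), size=min(n_points, len(xs)), replace=False)
    coords = np.stack([xs[idx], ys[idx]], axis=1).astype(np.float32)
    return coords, np.ones(len(coords), dtype=np.int32)
```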

jerome223 commented 1 year ago

It seems that converting self-made masks to logits was implemented in micro-sam for ellipse and polygon prompts, and it seems to work correctly. Function: def _compute_logits_from_mask(mask, eps=1e-3):

Link : https://github.com/computational-cell-analytics/micro-sam/blob/83997ff4a471cd2159fda4e26d1445f3be79eb08/micro_sam/prompt_based_segmentation.py#L375-L388
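I have not checked the linked implementation line by line; this is only one plausible reading of its `eps` argument, not the micro-sam code. The idea is to treat the binary mask as probabilities `1 - eps` (foreground) and `eps` (background) and take the inverse sigmoid. `logits_from_binary_mask` is a hypothetical helper; the resize-and-pad to 1x256x256 discussed earlier in the thread would still be needed afterwards.

```python
import numpy as np


def logits_from_binary_mask(mask: np.ndarray, eps: float = 1e-3) -> np.ndarray:
    """Hypothetical helper: foreground -> logit(1 - eps), background -> logit(eps)."""
    probs = np.where(np.asarray(mask).astype(bool), 1.0 - eps, eps)
    return np.log(probs / (1.0 - probs)).astype(np.float32)
```

With `eps=1e-3` the values are about ±6.91 (that is, ±ln 999), the same order of magnitude as the -8 to -10 background logits observed earlier in the thread.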

m13uz commented 10 months ago

You guys might want to check out these two repositories and try creating some sort of a pipeline stitching everything together

https://github.com/danielgatis/rembg/tree/main https://github.com/hkchengrex/CascadePSP

voplica-git commented 6 months ago

Does anyone know how to convert ultralytics (YOLO) masks to the mask_input for SAM? I tried both approaches above (reusing _compute_logits_from_mask and reusing reference_to_sam_mask), but it seems those are for different mask types. Please, if anyone has any clue here, share it and let me buy you a coffee 🙏