facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.37k stars, 5.61k forks

Checkered Predictions (Checkered Artifacts?) #778

Open: Vishawjeet-rmsl opened this issue 2 weeks ago

Vishawjeet-rmsl commented 2 weeks ago

Hi, I have been modifying the vanilla SAM scripts, mainly to write my own training script. That has mostly worked: training runs and the loss decreases gradually. However, when I save the model's predictions at every epoch, I see checkered lines all over them. For example, in the image below, the left prediction is from epoch 1 and the right one is from epoch 117. The grid fades over training, but it is still clearly visible.
[Image: prediction at epoch 1 (left) and at epoch 117 (right)]

Does anyone know what is causing this? Or is it just because the model has not been trained for enough epochs? Note that I'm using only 15 image-mask pairs for training (with the pretrained image encoder weights, but training the prompt encoder and mask decoder from scratch).

I would be grateful if someone could give me a clue, if not a full solution. Thanks in advance!

heyoeyo commented 2 weeks ago

The image you posted seems to have about 64 'tiles' horizontally. By default the image encoder outputs a 64x64 token 'image' that is processed and eventually upscaled by the mask decoder. So it seems likely that the upscaling isn't 'mixing' the pixels together enough and therefore the original 64x64 grid is still visible in the result.

There can also be other, much more subtle, lower-resolution grid artifacts that can appear due to the windowing the model uses.
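To illustrate the explanation above, here is a minimal single-channel sketch. It assumes the mask decoder's upscaling is two transposed convolutions with kernel_size=2 and stride=2 (64 → 128 → 256), which is how I understand SAM's `output_upscaling` to be built; the kernel values, the `tanh` stand-in activation and the helper name `upsample_k2s2` are hypothetical, and channels, LayerNorm and the hypernetwork dot product are left out. Because each token writes into its own non-overlapping 4x4 block, even a perfectly uniform token map produces a pattern that repeats every 4 pixels, i.e. 64 tiles across a 256-wide mask.

```python
import numpy as np

def upsample_k2s2(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    # Single-channel transposed convolution with kernel_size=2, stride=2, no bias.
    # Input pixel x[i, j] writes x[i, j] * kernel into the non-overlapping block
    # out[2i:2i+2, 2j:2j+2], which is exactly a Kronecker product.
    return np.kron(x, kernel)

# Hypothetical learned kernels (other non-trivial values show the same effect)
k1 = np.array([[1.0, 0.5], [0.2, 0.8]])
k2 = np.array([[0.9, 0.1], [0.4, 0.7]])

tokens = np.ones((64, 64))                   # perfectly uniform 64x64 token map
hidden = np.tanh(upsample_k2s2(tokens, k1))  # 128x128, element-wise activation
mask = upsample_k2s2(hidden, k2)             # 256x256 "mask logits"

print(mask.shape)  # (256, 256)
# The output repeats every 4 pixels in both directions (64 tiles per axis),
# even though every input token had the same value.
print(np.allclose(mask, np.roll(mask, 4, axis=0))
      and np.allclose(mask, np.roll(mask, 4, axis=1)))  # True
print(mask.max() - mask.min() > 0)  # True: the tiles are not flat
```

Any per-token variation then appears modulated by this fixed 4x4 pattern, which is one plausible source of the checkered look.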

Vishawjeet-rmsl commented 1 week ago

Wow! It does seem to have 64 tiles. If this is the reason, there may be an issue with the upsampling method. By the way, the images above are raw predictions (upscaled to the original image size) without thresholding. So I was wondering whether the interpolation could be another cause: the raw prediction from the mask decoder is 256x256, and it is then interpolated during post-processing. Maybe I should visualize the raw 256x256 predictions as well.
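One way to check whether the tiles come from the 64x64 token grid, before any post-processing interpolation, is to measure where horizontal edges pile up in the raw 256x256 low-res logits. The sketch below uses a hypothetical helper, `grid_contrast`, which is not part of SAM: it groups the absolute horizontal differences by column phase modulo the token period (4 px at 256x256) and reports max/mean over the phases. A value near 1.0 suggests no token-aligned grid; a value well above 1.0 suggests edges repeating at the token spacing. It assumes the map width is a multiple of 64, which holds for the raw low-res output but generally not after resizing to the original image size.

```python
import numpy as np

def grid_contrast(pred: np.ndarray, tokens_across: int = 64) -> float:
    """Ratio of the strongest to the average per-phase horizontal edge energy."""
    h, w = pred.shape
    period = w // tokens_across            # 4 px for a 256-wide low-res map
    dx = np.abs(np.diff(pred, axis=1))     # jump between columns c and c + 1
    phase = np.arange(w - 1) % period      # column phase within one token
    per_phase = np.array([dx[:, phase == p].mean() for p in range(period)])
    return float(per_phase.max() / per_phase.mean())

# Synthetic check: a smooth blob, with and without a faint 4-px grid
yy, xx = np.mgrid[0:256, 0:256]
blob = np.exp(-((yy - 128) ** 2 + (xx - 128) ** 2) / (2 * 40.0 ** 2))
grid = np.zeros((256, 256))
grid[:, ::4] = 0.2

print(round(grid_contrast(blob), 2))         # close to 1.0
print(round(grid_contrast(blob + grid), 2))  # close to 2.0
```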

heyoeyo commented 1 week ago

> So, I was wondering maybe another reason for this could be the interpolation?

Since the interpolation is bilinear, it probably shouldn't introduce any artifacts other than a blurring effect as the smaller pattern is scaled up.

However, the fact that the model doesn't upscale all the way back to the original input size may be part of the problem (there was some discussion of this on the SAM 2 issue tracker). It gives the model less chance to refine the original tokens, and any artifacts get interpolated up, which makes them more visible.
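To make the point about bilinear interpolation concrete, here is a small NumPy reimplementation of bilinear upscaling using the half-pixel (`align_corners=False`) convention; the helper names are my own, not SAM's or PyTorch's. A flat map stays flat, so the interpolation adds no grid of its own, while a 4-px grid on a 256x256 map is simply enlarged into a 16-px grid at 1024x1024, still 64 tiles across.

```python
import numpy as np

def _resize_axis(a: np.ndarray, out_len: int, axis: int) -> np.ndarray:
    in_len = a.shape[axis]
    # Half-pixel-centre mapping (align_corners=False); np.interp clamps
    # coordinates outside [0, in_len - 1] to the edge values.
    src = (np.arange(out_len) + 0.5) * in_len / out_len - 0.5
    xp = np.arange(in_len)
    return np.apply_along_axis(lambda v: np.interp(src, xp, v), axis, a)

def bilinear_upsample(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    # Bilinear = 1-D linear interpolation along rows, then along columns
    return _resize_axis(_resize_axis(img, out_h, 0), out_w, 1)

# 1) A flat 256x256 map stays flat: the interpolation adds no structure.
flat = np.full((256, 256), 0.3)
print(np.allclose(bilinear_upsample(flat, 1024, 1024), 0.3))  # True

# 2) A 4-px grid (64 tiles across) becomes a 16-px grid, still 64 tiles.
grid = np.zeros((256, 256))
grid[::4, :] = 1.0
grid[:, ::4] = 1.0
up = bilinear_upsample(grid, 1024, 1024)
# Away from the borders, shifting by 16 output rows leaves the map unchanged
print(np.allclose(up[16:512, 16:512], up[32:528, 16:512]))  # True
# Bilinear weights are non-negative, so values stay within the input range
print(up.min() >= 0.0 and up.max() <= 1.0)  # True
```

Since the weights are non-negative, bilinear interpolation cannot overshoot or ring; it can only soften and enlarge a pattern that is already present in the low-res output.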

bhack commented 1 week ago

I tried making the decoder upscale more smoothly by adding some extra layers (to 512x512 and then 1024x1024), so that it outputs 1024x1024 directly instead of the original 256x256 followed by pure interpolation, and I saw similar artifacts.

I think something else is still affecting the resolution.

heyoeyo commented 1 week ago

> I've tried to upscale the decoder more smoothly with some extra layers (512 and 1024) up to 1024x1024 instead of the original 256x256 + pure interpolation and I have seen similar artifacts

That's interesting! Maybe the decoder model is just too small or simple to avoid these kinds of artifacts entirely. It's probably hard to improve it without breaking the original 'real-time on CPU' design constraint. Maybe a few regular convolutions in between the upscaling steps could help blend things better spatially?
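The effect of adding regular convolutions between upscaling steps can be seen from an impulse response. In this single-channel sketch (hypothetical kernels and helper names; not SAM's actual decoder), one active token goes through two kernel-2/stride-2 upscaling steps, either directly or with a 3x3 'same' convolution in between. Without the 3x3 convolution, the token fills exactly one 4x4 tile with hard edges; with it, the token's influence extends into neighbouring tiles, which is what lets adjacent tokens blend across tile borders.

```python
import numpy as np

def upsample_k2s2(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    # Transposed conv, kernel_size=2, stride=2, single channel: a Kronecker product
    return np.kron(x, kernel)

def conv3x3_same(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    # 3x3 'same' cross-correlation with zero padding, single channel
    padded = np.pad(x, 1)
    h, w = x.shape
    out = np.zeros_like(x, dtype=float)
    for di in range(3):
        for dj in range(3):
            out += kernel[di, dj] * padded[di:di + h, dj:dj + w]
    return out

k_up = np.array([[1.0, 0.6], [0.8, 0.4]])  # hypothetical positive upscaling weights
k_mix = np.full((3, 3), 1.0 / 9.0)         # hypothetical 3x3 mixing kernel (box blur)

tokens = np.zeros((16, 16))
tokens[8, 8] = 1.0  # a single "active" token

plain = upsample_k2s2(upsample_k2s2(tokens, k_up), k_up)
mixed = upsample_k2s2(conv3x3_same(upsample_k2s2(tokens, k_up), k_mix), k_up)

print(np.count_nonzero(plain))  # 16: one 4x4 tile with hard edges
print(np.count_nonzero(mixed))  # 64: spills into neighbouring tiles
```

Whether this removes the grid in practice depends on the learned weights; as bhack reports in this thread, adding layers did not remove the artifacts in his experiments.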

bhack commented 1 week ago

> Maybe a few regular convolutions in between the upscaling steps could help blend things better spatially

That is what I tried, so as not to invalidate the pretrained checkpoint for the existing part of the decoder.

We probably need a better design for these extra layers.

Otherwise, the mask is strictly interpolated from 256x256 (for 1024x1024 inputs) at:

https://github.com/facebookresearch/sam2/blob/main/sam2%2Fmodeling%2Fsam2_base.py#L373
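For reference, the tile size to expect in a final, upscaled prediction follows from simple arithmetic. This sketch assumes SAM v1-style preprocessing, where `ResizeLongestSide` scales the longest image side to 1024 before padding (SAM 2's preprocessing may differ), and a 64x64 token grid, i.e. 16 px per token in 1024 space or 4 px in the 256x256 low-res mask; the helper name `token_tile_size` is hypothetical. Under these assumptions, a token-aligned grid should show about 64 tiles along the longest side regardless of image resolution.

```python
def token_tile_size(orig_h: int, orig_w: int, encoder_size: int = 1024,
                    grid: int = 64) -> tuple[float, float, float]:
    """Return (tile size in original pixels, tiles along height, tiles along width)."""
    scale = encoder_size / max(orig_h, orig_w)  # longest side resized to encoder_size
    token_px = encoder_size / grid              # 16 px per token at 1024x1024
    tile = token_px / scale                     # token size back in original pixels
    return tile, orig_h / tile, orig_w / tile

for h, w in [(768, 1024), (1500, 2000)]:
    tile, n_h, n_w = token_tile_size(h, w)
    print(f"{h}x{w}: {tile:.2f} px per tile, {n_h:.1f} x {n_w:.1f} tiles")
```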