facebookresearch / segment-anything

The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Is there only one input shape for the ViT to correctly output a mask? #696

Open w1005444804 opened 6 months ago

w1005444804 commented 6 months ago

If the input to the ViT is not 1024x1024 but something else, such as 1024x512 or 768x512, can SAM still accurately output the mask?

heyoeyo commented 6 months ago

If you provide an image that isn't 1024x1024, then the model will scale your input so that the longest side is 1024 and then the shorter side is padded with zeros. So in other words, it just turns them into 1024x1024 images before any processing. There are a variety of image sizes/aspect ratios on the main SAM demo site, and most of the non-square images seem to work fine.
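For reference, here is a rough standalone sketch of that resize-and-pad step. The real pipeline does this via ResizeLongestSide in segment_anything.utils.transforms plus Sam.preprocess (which also normalizes by the pixel mean/std); this is just an illustration of the idea, not the library's exact code:

import torch
import torch.nn.functional as F

def resize_and_pad(image_bchw: torch.Tensor, target_length: int = 1024) -> torch.Tensor:
    # image_bchw: float tensor of shape (B, 3, H, W)
    # Scale so the longest side equals target_length, then zero-pad to a square
    _, _, h, w = image_bchw.shape
    scale = target_length / max(h, w)
    new_h, new_w = round(h * scale), round(w * scale)
    resized = F.interpolate(image_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
    pad_h, pad_w = target_length - new_h, target_length - new_w
    return F.pad(resized, (0, pad_w, 0, pad_h))  # pad order: (left, right, top, bottom)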

w1005444804 commented 6 months ago

@heyoeyo You're right. but I want to know if it's possible to choose different shapes for the input of VIT(encoder) .

heyoeyo commented 6 months ago

As-is, if you skip the built-in pre-processing and provide a different input size, then the image encoder will fail when trying to add the position embeddings, which have a 64x64 shape (the height & width of the patches after the patch embedding step using a 1024x1024 input).
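To make the mismatch concrete, here's a tiny standalone illustration (shapes assume the ViT-B config with embed_dim=768; the other checkpoints only differ in the embedding dimension):

import torch

# The encoder's absolute position embedding is fixed at a 64x64 grid
# (img_size // patch_size = 1024 // 16), e.g. for ViT-B:
pos_embed = torch.zeros(1, 1024 // 16, 1024 // 16, 768)      # shape (1, 64, 64, 768)
patches_768x512 = torch.zeros(1, 768 // 16, 512 // 16, 768)  # shape (1, 48, 32, 768)
# patches_768x512 + pos_embed  -> shape mismatch: 48x32 grid vs 64x64 grid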

The image encoder is a fairly simple transformer and can (in theory) deal with a wide range of input sizes, as long as you modify the positional encodings (which has to be done manually, since it's not in the original code base). As far as I know, the standard approach is to resize them to match the input size; for example, DINOv2 does this, and I know the MiDaS BEiT model does as well. You'd have to change the position embedding step in the image encoder to something like:

# (inside ImageEncoderViT.forward, after patch embedding; uses torch.nn.functional as F)
if self.pos_embed is not None:
    # x is (B, H, W, C) after patch embedding; resize the embedding grid to match it
    _, h, w, _ = x.shape
    pos_embed = self.pos_embed.permute(0, 3, 1, 2)             # (1, H, W, C) -> (1, C, H, W)
    pos_embed = F.interpolate(pos_embed, size=(h, w), mode="bilinear")
    pos_embed = pos_embed.permute(0, 2, 3, 1)                  # back to (1, H, W, C)
    x = x + pos_embed

This should at least allow the ViT component to run at different sizes. However, the rest of the SAM model (e.g. the pre-/post-processing, the mask decoder and maybe the learned prompt embeddings) is also built around the 1024x1024 sizing and would need to be modified to handle other sizes properly (see the build_sam.py excerpt below for where those sizes are set).
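For example, paraphrasing build_sam.py, the 1024/64 sizes are wired into the rest of the model when it's assembled:

from segment_anything.modeling import PromptEncoder

# Paraphrased from build_sam.py: the sizes the rest of the model is assembled around
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size  # 64
prompt_encoder = PromptEncoder(
    embed_dim=prompt_embed_dim,
    image_embedding_size=(image_embedding_size, image_embedding_size),
    input_image_size=(image_size, image_size),
    mask_in_chans=16,
)
# Sam.postprocess_masks likewise upsamples the low-res masks to image_encoder.img_size (1024)
# before cropping off the padding and resizing to the original image size.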

w1005444804 commented 6 months ago

@heyoeyo Thank you very much for your answer. Do you know how to make it run at FP16 precision, while at the same time forcing the layernorm layers to run in FP32 precision?

heyoeyo commented 6 months ago

Do you know how to make it run at FP16 precision, while at the same time forcing the layernorm layers to run in FP32 precision?

For the image encoder at least, the only use of layernorms is in the transformer blocks, as far as I can tell. Assuming the model is in float16 to begin with, a hacky approach would be to first cast the layernorms back to float32 at the start of the block's forward function using something like:

# Switch norm layers to float32
orig_config = {"dtype": x.dtype}
f32_config = {"dtype": torch.float32}
self.norm1.to(**f32_config)
self.norm2.to(**f32_config)

Then switch the input data (i.e. x) to float32 for the norm layer calculation and back again. You'd need to modify the norm1 and norm2 lines to something like:

x = self.norm1(x.to(**f32_config)).to(**orig_config)
...
x = x + self.mlp(self.norm2(x.to(**f32_config)).to(**orig_config))

However, while the conversion above does the calculations in float32, the layernorm weights are being up-cast from f16 to f32. If you wanted to use the original f32 layernorm weights, you would need to re-load just the layernorm weights in f32 while leaving the rest of the model in f16. I'm not sure there is a nice way to do this... It might be possible by loading the original weights, removing all non-norm entries, and then using the load_state_dict function with strict set to False, but I've never tried that, so I'm not sure if it works.
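An untested sketch of that partial reload, with sam being the already-loaded model. It assumes the transformer layernorm parameters have "norm" in their checkpoint key names (the LayerNorm2d layers in the neck are named differently and would need their own filter), and that the norm modules have already been cast back to float32 so the copied values aren't immediately down-cast:

import torch

ckpt = torch.load("sam_vit_h_4b8939.pth", map_location="cpu")  # original float32 checkpoint
norm_only = {k: v for k, v in ckpt.items() if "norm" in k}     # keep only the layernorm entries
# strict=False ignores all the missing (non-norm) keys
missing, unexpected = sam.load_state_dict(norm_only, strict=False)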

w1005444804 commented 6 months ago

@heyoeyo It seems that what you said is right, but I don't have much experience using FP16 in PyTorch. I have converted the ViT (encoder) to an ONNX model and then run inference with TensorRT; however, it fails in FP16, because the layernorm after self-attention can overflow in FP16. So I'd like to let the other layers run in FP16 while the layernorm layers run in FP32; I think that will be OK.
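(Not from the thread, just a rough sketch of how this is commonly done with the TensorRT Python API: build with the FP16 flag, then pin the layernorm layers to FP32 via per-layer precision constraints. It assumes the ONNX export used an opset where LayerNorm survives as a single node with "LayerNorm" in its name; with older opsets the norms get decomposed into ReduceMean/Sub/Pow/Div nodes and you'd have to match those instead. The ONNX path is hypothetical.)

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("sam_image_encoder.onnx", "rb") as f:  # hypothetical path to the exported encoder
    parser.parse(f.read())

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)  # make per-layer settings binding

# Pin anything that looks like a layernorm (matched by node name) to float32
for i in range(network.num_layers):
    layer = network.get_layer(i)
    if "LayerNorm" in layer.name:
        layer.precision = trt.float32
        layer.set_output_type(0, trt.float32)

engine_bytes = builder.build_serialized_network(network, config)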

Matagi1996 commented 3 months ago

Hey, thank you very much for the tip that the layernorms are the reason results are not good in FP16. I wrote a little script using forward hooks to convert between FP32 and FP16 just for the layernorms (and the custom 2D layernorm), and it seems to produce good results.

import torch
import torch.nn as nn
from segment_anything.modeling.common import LayerNorm2d

sam = ...  # load ckpt in FP32
half = True
if half:
    def convert_to_fp32(module, input):
        # pre-hook: cast the (FP16) inputs to FP32 before the norm runs
        return tuple(inp.float() for inp in input)

    def convert_to_fp16(module, input, output):
        # post-hook: cast the norm output back to FP16 for the following layers
        return output.half()

    def apply_hooks(module: nn.Module):
        if isinstance(module, (torch.nn.LayerNorm, LayerNorm2d)):
            module.float()  # keep the LayerNorm layers in FP32
            module.register_forward_pre_hook(convert_to_fp32)
            module.register_forward_hook(convert_to_fp16)
        else:
            module.half()  # convert other layers to FP16

        # recursively apply hooks to child modules; parents are visited first, so a parent's
        # .half() converts everything, and LayerNorm children are cast back to FP32 when reached
        for child in module.children():
            apply_hooks(child)

    apply_hooks(sam.image_encoder)
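One small usage note (my assumption, not part of the script above): once the encoder is mixed precision like this, the input tensor itself also has to be FP16, since the patch-embedding conv now runs in half precision. Something like:

sam.cuda()
with torch.no_grad():
    # batched_input: hypothetical name for an already-preprocessed 1x3x1024x1024 image tensor
    image_embedding = sam.image_encoder(batched_input.half().cuda())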