facebookresearch / segment-anything

The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Extraction of image embeddings / feature vectors (latent space) #673

Open BilAlHomsi opened 8 months ago

BilAlHomsi commented 8 months ago

Hello, I'm trying to understand how SAM works. I am interested in extracting the image embeddings created by ImageEncoderViT. Also, I'm interested in the output after combining image embeddings and masks through the Conv-operation (before mask decoding).

I look forward to your help.

Best regards

heyoeyo commented 8 months ago

You can get the image embeddings from the predictor after calling the .set_image(...) function, using predictor.features. There's a more detailed explanation in issue #665.

Getting the combined image embedding + mask embedding is a bit trickier. The Conv-operation you're referring to seems to happen in the prompt encoder, but the image embeddings & masks get combined in the mask decoder, so the result isn't easily accessible. Following the code from #665, you could probably do something like:

import torch

# ... load model/image and do all setup stuff

# Get embeddings by running model components
# (input_image_as_tensor = the resized, normalized & padded 1x3x1024x1024 input from predictor.model.preprocess)
image_embeddings = predictor.model.image_encoder(input_image_as_tensor)
point_embeddings, mask_embeddings = predictor.model.prompt_encoder(...)

# Run the part of the mask decoder that adds the image+mask embeddings
# Everything after this point is just copied/modified from:
# https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L120C1-L127C44
iou_weight = predictor.model.mask_decoder.iou_token.weight
mask_weight = predictor.model.mask_decoder.mask_tokens.weight

# Concatenate output tokens
output_tokens = torch.cat([iou_weight, mask_weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(point_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, point_embeddings), dim=1)

# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + mask_embeddings

The final src result is the image embeddings + mask embeddings after the downscaling conv step. In the original code, the src result also gets positional encodings added to it right after the mask embedding, which isn't included in the code above.
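You can check the shape bookkeeping in these steps without loading a checkpoint. The sketch below substitutes random tensors for SAM's outputs. The sizes (256 channels on a 64x64 grid, 1 IoU token plus 4 mask tokens) assume the default SAM configuration. The prompt counts are hypothetical, and `combine` is an illustrative helper, not part of the SAM API.

```python
import torch

# Random stand-ins for SAM outputs; shapes assume the default 1024px input (64x64 grid, 256 channels)
embed_dim, grid = 256, 64
num_prompts = 3       # hypothetical: three separate prompt sets for one image
num_point_tokens = 2  # hypothetical: e.g. one point plus SAM's padding point

image_embeddings = torch.randn(1, embed_dim, grid, grid)                  # one image
point_embeddings = torch.randn(num_prompts, num_point_tokens, embed_dim)  # sparse embeddings
mask_embeddings = torch.randn(num_prompts, embed_dim, grid, grid)         # dense embeddings
iou_weight = torch.randn(1, embed_dim)   # like mask_decoder.iou_token.weight
mask_weight = torch.randn(4, embed_dim)  # like mask_decoder.mask_tokens.weight


def combine(image_embeddings, point_embeddings, mask_embeddings, iou_weight, mask_weight):
    """Mirror the start of the mask decoder's predict_masks, before the transformer runs."""
    output_tokens = torch.cat([iou_weight, mask_weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(point_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, point_embeddings), dim=1)
    # One copy of the image embedding per prompt set, plus that prompt set's dense embedding
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = src + mask_embeddings
    return tokens, src


tokens, src = combine(image_embeddings, point_embeddings, mask_embeddings, iou_weight, mask_weight)
print(tokens.shape)  # torch.Size([3, 7, 256])
print(src.shape)     # torch.Size([3, 256, 64, 64])
```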

BilAlHomsi commented 8 months ago

Thanks @heyoeyo for your explanation and your support.

Which arguments does the method take? In your opinion, does it also make sense to use positional encoding in this case? If so, how can I integrate it into the code? Does combining the masks and image embeddings also provide added value or useful information?

Here is a section of the code from my work:

import sys

import cv2
import torch

sys.path.append("..")
from segment_anything import SamPredictor, sam_model_registry

image = cv2.imread('image1.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam_checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

predictor.set_image(image)

with torch.inference_mode():  # torch.no_grad()
    image_embeddings = predictor.features

    msk, iou_, _ = predictor.predict_torch(
        point_coords=None, point_labels=None)

# Redo the preprocessing that set_image() performs, to run the image encoder directly
transformed_image = predictor.transform.apply_image(image)
transformed_image = torch.as_tensor(
    transformed_image, device=predictor.device)
# HWC -> NCHW: add a batch dimension at the front, then move channels to dim 1
transformed_image = transformed_image.unsqueeze(0)
transformed_image = transformed_image.permute(0, 3, 1, 2).contiguous()
preproc_img = predictor.model.preprocess(transformed_image)

image_tensor = preproc_img

with torch.inference_mode():  # torch.no_grad()
    features_new = predictor.model.image_encoder(image_tensor)  # should match predictor.features

    point_box_embeddings, mask_embeddings = predictor.model.prompt_encoder(points=None, boxes=None, masks=None)

    iou_weight = predictor.model.mask_decoder.iou_token.weight
    mask_weight = predictor.model.mask_decoder.mask_tokens.weight

    # Concatenate output tokens
    output_tokens = torch.cat([iou_weight, mask_weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(point_box_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, point_box_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = src + mask_embeddings

heyoeyo commented 8 months ago

how can I integrate it (positional encoding) into the code.

Actually, I was wrong about them being added right after: they are only generated at that point, and are then used within the transformer of the mask decoder. If you'd like to generate them, you can do something similar to the code in the mask decoder:

# Image positional encodings
image_pe = predictor.model.prompt_encoder.get_dense_pe()
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
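For reference, here is a standalone sketch of the kind of encoding get_dense_pe() produces. It is based on my reading of SAM's PositionEmbeddingRandom, which uses random Fourier features over a grid of cell centers. The Gaussian matrix is random here. In a loaded model it would presumably be the positional_encoding_gaussian_matrix buffer on predictor.model.prompt_encoder.pe_layer, and with that buffer the result should match get_dense_pe(), but treat that as an assumption to verify. dense_positional_encoding is a hypothetical name.

```python
import math

import torch


def dense_positional_encoding(gaussian_matrix: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Sketch of a random-Fourier-feature grid encoding, returned as a (1, C, h, w) tensor.

    gaussian_matrix has shape (2, C/2); random here, a stored buffer in a real SAM model.
    """
    grid = torch.ones((h, w), dtype=torch.float32)
    # Cell-center coordinates normalized to [0, 1]
    y_embed = (grid.cumsum(dim=0) - 0.5) / h
    x_embed = (grid.cumsum(dim=1) - 0.5) / w
    coords = torch.stack([x_embed, y_embed], dim=-1)  # (h, w, 2)
    coords = 2 * coords - 1                            # map to [-1, 1]
    coords = 2 * math.pi * (coords @ gaussian_matrix)  # (h, w, C/2)
    pe = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)  # (h, w, C)
    return pe.permute(2, 0, 1).unsqueeze(0)


image_pe = dense_positional_encoding(torch.randn(2, 128), 64, 64)
print(image_pe.shape)  # torch.Size([1, 256, 64, 64])
```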

As for actually making use of the positional encodings, that happens in a much more complicated way where they affect every layer inside the transformer of the mask decoder.
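In predict_masks, the mask decoder calls its transformer as hs, src = self.transformer(src, pos_src, tokens). So running predictor.model.mask_decoder.transformer(src, pos_src, tokens) on the tensors above should reproduce that step. The gist of how the encodings are used is that, in the token/image cross-attention, positional encodings are added to the queries and keys but not to the values. The simplified sketch below uses torch.nn.MultiheadAttention as a stand-in for SAM's own attention module, which differs in details such as a reduced internal dimension. It only mimics the first layer, where the queries are still the original tokens.

```python
import torch
from torch import nn

embed_dim, num_heads = 256, 8
src = torch.randn(1, embed_dim, 64, 64)      # image + mask embedding (random stand-in)
pos_src = torch.randn(1, embed_dim, 64, 64)  # dense positional encoding (random stand-in)

# Flatten (N, C, H, W) grids into (N, H*W, C) sequences, as the decoder's transformer does
keys = src.flatten(2).permute(0, 2, 1)        # (1, 4096, 256)
key_pe = pos_src.flatten(2).permute(0, 2, 1)  # (1, 4096, 256)

tokens = torch.randn(1, 7, embed_dim)  # output tokens + prompt tokens
query_pe = tokens                      # the original tokens act as the query-side encoding

cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
# Positional information goes into queries and keys only; values are the plain image features
attn_out, _ = cross_attn(query=tokens + query_pe, key=keys + key_pe, value=keys)
print(attn_out.shape)  # torch.Size([1, 7, 256])
```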

Does the combination of masks and image embeddings also provide added value or useful information

This will depend on what you want to do with the data I guess. The image embeddings are probably quite useful as a general representation of the original input image. I'm sure they could be re-used with other models to do things like object detection or more general semantic segmentation. Adding the mask embedding to the image embedding probably makes it less useful in a general sense, but more useful for things similar to what the SAM mask decoder is already doing. So if your goal is to modify/alter the original SAM behavior, then it may be better to use the combined image + mask embedding.

BilAlHomsi commented 8 months ago

Hello @heyoeyo, thanks for your support.

I've already figured it out and integrated it into my code.

Regarding my earlier question about the prompt encoder parameters: if points, boxes, and masks are all set to None, is the image considered as a whole? If not, how are dense_embeddings (mask_embeddings) and sparse_embeddings (point_embeddings) affected?

Another question: if I detect all masks in an image, integrate them into the original image, and then encode that with SAM, are the resulting embeddings the same as the combined embeddings I would get by encoding the image and the masks separately?

Thanks a lot!

heyoeyo commented 8 months ago

if the points, boxes, and masks are set to None ...

If no prompt is provided, the sparse/point embedding will be an empty tensor (shape 1x0x256, i.e. zero tokens), so it won't have any influence. The dense/mask embedding will default to the learned no_mask_embed embedding. From this, I would assume the segmentation result to be fairly random, since no target is specified and (as far as I know) the model wasn't trained to operate this way. I haven't tried using it this way myself though, so maybe it does something more interesting?
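Here is a minimal sketch of what the prompt encoder returns when points, boxes and masks are all None, assuming the default 256-channel, 64x64 configuration. The no_mask_embed below is a randomly initialized stand-in for the learned predictor.model.prompt_encoder.no_mask_embed.

```python
import torch
from torch import nn

embed_dim, grid = 256, (64, 64)
no_mask_embed = nn.Embedding(1, embed_dim)  # stand-in for the learned no-mask embedding

batch_size = 1  # with no prompts at all, the batch size falls back to 1
# Sparse (point/box) embeddings: a tensor holding zero tokens
sparse_embeddings = torch.empty((batch_size, 0, embed_dim))
# Dense (mask) embeddings: the single learned vector repeated at every grid position
dense_embeddings = no_mask_embed.weight.reshape(1, -1, 1, 1).expand(batch_size, -1, *grid)

print(sparse_embeddings.shape)  # torch.Size([1, 0, 256])
print(dense_embeddings.shape)   # torch.Size([1, 256, 64, 64])
```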

If I detect all masks from an image and integrate them into the original image

I'm not sure if I understand what you mean by integrating the masks into the image. If you mean masking out parts of the image and passing that through the model, I would expect the embeddings to be different generally. As far as I know, the mask input into the mask decoder is not meant to be binary, it's supposed to take on a range of positive & negative values (similar to the raw output of the SAM model). So just from that alone, since masking an image requires a binary mask, but the mask decoder takes non-binary masks, I would assume the masked image embedding vs. image embedding + embedded mask to be different.

tcourat commented 7 months ago

Hi, you can keep the embeddings during inference by using forward hooks.

For instance, if you want to store the image encoder features while running the mask generator:

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

img = None  # your image here (RGB, HxWx3 uint8 numpy array)

sam_checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to("cuda")
sam.eval()
mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=64, pred_iou_thresh=0.5, stability_score_thresh=0.8)

features = {}
def get_features(name):
    def hook(model, input, output):
        features[name] = output.detach()
    return hook
# neck[3] is the last layer of the image encoder, so its output is the image embedding
handle = mask_generator.predictor.model.image_encoder.neck[3].register_forward_hook(get_features('features'))
masks = mask_generator.generate(img)
handle.remove()  # stop capturing once done

Then you will have access to the masks, and also to the image encoder features in the features dict. I guess you can do something similar for the mask embeddings.
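The hook mechanism itself is plain PyTorch, so you can try it on a toy model. In this sketch a small nn.Sequential stands in for SAM, with layer model[0] playing the role of image_encoder.neck[3]. Keeping the handle returned by register_forward_hook lets you remove the hook afterwards. For the mask embeddings, the analogous target would presumably be sam.prompt_encoder.mask_downscaling, whose hook should only fire when a mask is actually passed in. That is an assumption from reading prompt_encoder.py, not something tested here.

```python
import torch
from torch import nn

# Toy stand-in for SAM; model[0] plays the role of image_encoder.neck[3]
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=4, stride=4), nn.ReLU())

features = {}


def get_features(name):
    def hook(module, inputs, output):
        # detach() so the stored tensor does not keep an autograd graph alive
        features[name] = output.detach()
    return hook


handle = model[0].register_forward_hook(get_features("features"))
with torch.no_grad():
    model(torch.randn(1, 3, 64, 64))
handle.remove()  # later forward passes no longer overwrite features["features"]

print(features["features"].shape)  # torch.Size([1, 8, 16, 16])
```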

BilAlHomsi commented 7 months ago

Hey,

I have another question: when I pass masks to the call, the mask embeddings (mask_embeddings) I get have a different shape, specifically a different height and width. Is it possible to adjust the SAM code so that the shape becomes uniform without affecting the quality of the mask embeddings?

Thanks a lot for any help

Code:

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

with torch.inference_mode():  # torch.no_grad()
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)

predictor = SamPredictor(sam)
predictor.set_image(image)

image_embeddings = predictor.get_image_embedding()  # == predictor.features

segmentation = np.array(masks[0]['segmentation'])
segmentation = segmentation.astype(np.float32)
segmentation = np.expand_dims(segmentation, axis=0)  # Add channel dimension
segmentation = np.expand_dims(segmentation, axis=0)  # Add batch dimension
mask_tensor = torch.as_tensor(segmentation, device=predictor.device)

# Get the generated mask embeddings from SAM
_, mask_embeddings = predictor.model.prompt_encoder(points=None, boxes=None, masks=mask_tensor)
heyoeyo commented 7 months ago

If you mean having the output mask_embeddings match the size of the input mask_tensor, then one way would be to scale up by a factor of 4 with something like:

upscaled_mask_embeddings = torch.nn.functional.interpolate(mask_embeddings, scale_factor=4)

If you wanted to change the code so that the result is generated at a higher resolution directly, then I think you could remove the stride on the convolution layers which downscale the embeddings. You might also need to add padding (something like padding="same") to avoid losing 'pixels' around the edges.
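You can see the effect of removing the stride on a single conv layer. This sketch uses a randomly initialized kernel-2, stride-2 Conv2d like the ones in SAM's mask downscaling path. In the model these would presumably be prompt_encoder.mask_downscaling[0] and [3]. It then copies those weights into a stride-1 conv with padding="same". Keep two caveats in mind. First, the pretrained weights were learned with the stride in place, so the full-resolution output is not guaranteed to stay meaningful. Second, the result would no longer match the 64x64 image embedding it gets added to. PyTorch may also warn about even kernel sizes with "same" padding, since the padding has to be asymmetric.

```python
import torch
from torch import nn

mask = torch.randn(1, 1, 256, 256)

# Kernel-2 / stride-2 conv, as in SAM's mask downscaling: halves height and width
strided = nn.Conv2d(1, 4, kernel_size=2, stride=2)
# Same kernel size with stride 1 and 'same' padding keeps the input resolution
unstrided = nn.Conv2d(1, 4, kernel_size=2, stride=1, padding="same")
unstrided.load_state_dict(strided.state_dict())  # reuse the (pretrained, in practice) weights

with torch.no_grad():
    print(strided(mask).shape)    # torch.Size([1, 4, 128, 128])
    print(unstrided(mask).shape)  # torch.Size([1, 4, 256, 256])
```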

BilAlHomsi commented 7 months ago

Hey @heyoeyo,

thanks for your support.

The embeddings shape without providing masks is --> torch.Size([1, 256, 64, 64])

# Get the generated mask embeddings from SAM
_, mask_embeddings = predictor.model.prompt_encoder(points=None, boxes=None, masks=None)

The embeddings shape when providing masks is something like --> torch.Size([1, 256, 303, 386]); that shape depends on the input image.

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

segmentation = np.array(masks[0]['segmentation'])
segmentation = segmentation.astype(np.float32)
segmentation = np.expand_dims(segmentation, axis=0)  # Add channel dimension
segmentation = np.expand_dims(segmentation, axis=0)  # Add batch dimension
mask_tensor = torch.as_tensor(segmentation, device=predictor.device)

# Get the generated mask embeddings from SAM
_, mask_embeddings = predictor.model.prompt_encoder(points=None, boxes=None, masks=mask_tensor)

What is the reason behind the different shapes, and how can I get the shape torch.Size([1, 256, 64, 64]) when providing masks? Do the mask embeddings have positional encodings? If so, how can I access them?

Thanks again for your help

heyoeyo commented 7 months ago

What is the reason behind the different shapes and how to get the shape --> torch.Size([1, 256, 64, 64])

The 64x64 shape is the size of the output of the SAM image encoder (by default), which assumes a 1024x1024px input image and uses 16px patches, so 1024/16 = 64 patches in x & y. The image encoder also always outputs 256 channels, regardless of the SAM model size.
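Here is the arithmetic as a small sketch. The resizing step mirrors the rounding in SAM's ResizeLongestSide: the longest side is scaled to 1024, then the image is zero-padded to 1024x1024. So the encoder grid is 64x64 regardless of the original aspect ratio. The 1212x1544 image size is hypothetical, chosen only because it is consistent with the 303x386 embedding shape reported above. The function names are illustrative.

```python
def preprocess_shape(h: int, w: int, long_side: int = 1024) -> tuple:
    """Size after scaling the longest side to long_side (same rounding as ResizeLongestSide)."""
    scale = long_side / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)


def encoder_grid(img_size: int = 1024, patch_size: int = 16) -> int:
    """Patches per side of the padded, square encoder input."""
    return img_size // patch_size


print(preprocess_shape(1212, 1544))  # (804, 1024), then zero-padded to 1024x1024
print(encoder_grid())                # 64 patches per side
print(encoder_grid() * 4)            # 256: mask side that downscales by 4 onto the same grid
```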

The mask input of the prompt encoder expects a mask which is 256x256px. The paper describes this on pg. 16 under the Prompt Encoder section. The prompt encoder downscales this mask by a factor of 4 to match the 64x64 sizing while also creating the 256 channel size, so that the mask embedding can be added to the image embedding later on.
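To see where both the 64x64 and the 303x386 shapes come from, here is a re-creation of the prompt encoder's mask downscaling path with random weights. It is based on my reading of prompt_encoder.py: mask_in_chans=16, GELU activations, and a LayerNorm2d written the way SAM's common.py writes it. Each kernel-2, stride-2 conv halves the height and width (rounding down), so the output is the input size divided by 4, whatever that input size is.

```python
import torch
from torch import nn


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of an (N, C, H, W) tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


mask_in_chans, embed_dim = 16, 256
# Two kernel-2/stride-2 convs (4x downscale overall), then a 1x1 conv up to 256 channels
mask_downscaling = nn.Sequential(
    nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans // 4),
    nn.GELU(),
    nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans),
    nn.GELU(),
    nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)

with torch.no_grad():
    print(mask_downscaling(torch.zeros(1, 1, 256, 256)).shape)    # torch.Size([1, 256, 64, 64])
    print(mask_downscaling(torch.zeros(1, 1, 1212, 1544)).shape)  # torch.Size([1, 256, 303, 386])
```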

Do the mask embeddings have positional encodings?

No, if a mask is provided as input, the only thing the prompt encoder does is that 4x downscale/convolution.

BilAlHomsi commented 7 months ago

Hey @heyoeyo,

Thanks for your effort, but something is wrong. Even after downscaling, the mask embeddings I get when providing masks do not have the same shape as the image embeddings, (1, 256, 64, 64).

What is the reason behind the mask embedding shape of torch.Size([1, 256, 303, 386])?

How can the mask embeddings be reduced to (1, 256, 64, 64) when masks are provided?

Many thanks

heyoeyo commented 7 months ago

If the input mask isn't 256x256, then the embedding result won't be 64x64, since it only does a 4x downscale (i.e. it isn't scaling to 64x64 directly).

You could either scale the input mask to be 256x256, so that it gets downscaled to 64x64 or otherwise use the original input sizing and instead scale the embedding result to be 64x64. In both cases, the same torch.nn.functional.interpolate function from above can be used, but instead of using the scale_factor arg you'd want to use the size arg to set the exact sizing to scale to.
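Here is a sketch of both options. One assumption is worth flagging. The image embedding corresponds to the image resized so its longest side is 1024 and then zero-padded to 1024x1024. So for a non-square image, a mask that is simply stretched to 256x256 may not line up with it. The hypothetical helper mask_to_prompt_input therefore resizes the longest side to 256 and pads the bottom/right, mirroring the image preprocessing. Also keep in mind the earlier point that SAM's own mask inputs are usually non-binary, logit-like values, so a 0/1 mask may not behave the same way.

```python
import torch
import torch.nn.functional as F


def mask_to_prompt_input(mask: torch.Tensor, target: int = 256) -> torch.Tensor:
    """Resize an (H, W) mask so its longest side is `target`, then zero-pad bottom/right
    to (1, 1, target, target), mirroring how SAM resizes and pads the image itself."""
    h, w = mask.shape
    scale = target / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    x = mask.float()[None, None]  # (1, 1, H, W)
    x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
    return F.pad(x, (0, target - new_w, 0, target - new_h))


# Option 1: shrink the mask before it goes into the prompt encoder
binary_mask = torch.ones(1212, 1544, dtype=torch.bool)  # hypothetical full-resolution mask
mask_input = mask_to_prompt_input(binary_mask)
print(mask_input.shape)  # torch.Size([1, 1, 256, 256])

# Option 2: keep the full-size mask and resize the resulting embedding instead
mask_embeddings = torch.randn(1, 256, 303, 386)  # stand-in for the prompt encoder output
resized = F.interpolate(mask_embeddings, size=(64, 64), mode="bilinear", align_corners=False)
print(resized.shape)  # torch.Size([1, 256, 64, 64])
```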