NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.

Visualizing SAM Mask decoder #385

Open nahidalam opened 8 months ago

nahidalam commented 8 months ago

I was looking into your tutorial on visualizing the self-attention of DINO.

I'm planning to do something similar to visualize the attention heads of SAM's mask decoder. Based on the SAM paper and codebase, there should be 8 attention heads, but the code below shows only 1 attention head.
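For reference, the default SamMaskDecoderConfig in transformers also reports 8 heads. A quick sanity check, just printing the config default:

from transformers import SamMaskDecoderConfig

# the mask decoder config defaults to 8 attention heads
print(SamMaskDecoderConfig().num_attention_heads)  # prints 8

Here is my full reproduction: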

from PIL import Image
import requests

from transformers import SamModel, SamProcessor
from transformers import SamVisionConfig, SamConfig, SamPromptEncoderConfig, SamMaskDecoderConfig

# load an example image (the snippet assumes an `image` variable; any RGB image works)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# define SAM configs
vision_config = SamVisionConfig(patch_size=16)
prompt_encoder_config = SamPromptEncoderConfig()
mask_decoder_config = SamMaskDecoderConfig()
samconfig = SamConfig(
    vision_config=vision_config,
    prompt_encoder_config=prompt_encoder_config,
    mask_decoder_config=mask_decoder_config,
)

# define model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base", config=samconfig)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# image preprocessing (SamProcessor resizes and pads to 1024x1024,
# so a separate ViT feature extractor isn't needed)
pixel_values = processor(images=image, return_tensors="pt").pixel_values

# forward pass, requesting attention weights
outputs = model(pixel_values, output_attentions=True, return_dict=True)
[Screenshot: printed shape of the returned attention tensor; its 2nd dimension is 1]

The 2nd dimension of the above tensor should be 8, not 1.
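In case it helps, this is how I am inspecting the shapes. I am assuming the per-head dimension is dimension 1 of each tensor in outputs.mask_decoder_attentions (the attentions field of SamImageSegmentationOutput):

# print the shape of each attention tensor returned by the mask decoder
for i, attn in enumerate(outputs.mask_decoder_attentions):
    print(f"mask decoder attention {i}: {tuple(attn.shape)}")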

Am I missing something?