Hi, have you tried to implement that?
@Gugliebiagio You can't use the `forward` outputs, because `forward` slices out just the global [CLS] embedding: https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L214

Instead, you have to get the raw token sequence using `get_intermediate_layers`:
```python
import torch

model = torch.hub.load("facebookresearch/dino:main", "dino_vits16")
model.eval()

img = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    features = model(img)
    print(features.shape)  # torch.Size([1, 384]) -- already pooled to the [CLS] embedding

    # Returns a list with the outputs of the last n blocks; take the final one.
    features = model.get_intermediate_layers(img, len(model.blocks))
    features = features[-1]  # residual stream @ final block
    print(features.shape)  # torch.Size([1, 197, 384])
```
In this example, you get 1 + 14² = 197 embeddings, each of which is 384-dimensional. You can drop the first one (the [CLS] token) and map the remaining 196 embeddings back to the 14×14 grid of patches.
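As a minimal sketch of that last step (using a random tensor in place of the real model output, so the shapes are the only thing being demonstrated): drop the [CLS] token and reshape the 196 patch tokens into the 14×14 spatial grid.

```python
import torch

# Stand-in for the output of get_intermediate_layers on a 224x224 image
# through ViT-S/16: 1 [CLS] token + (224/16)^2 = 196 patch tokens, 384-d each.
tokens = torch.randn(1, 197, 384)

patch_tokens = tokens[:, 1:, :]               # drop [CLS] -> [1, 196, 384]
grid = patch_tokens.reshape(1, 14, 14, 384)   # [batch, rows, cols, dim]
print(grid.shape)  # torch.Size([1, 14, 14, 384])
```

Each `grid[0, i, j]` is then the embedding of the patch at row `i`, column `j` of the image.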
I saw that you shared code for evaluating image-level retrieval (Issue & code). I am keen to replicate the patch-level retrieval from the official DINO demos here.
Thank you for your guidance! If there are any code snippets or resources that you can share to replicate the demo, I would be most grateful.