facebookresearch / dino

PyTorch code for Vision Transformers training with the Self-Supervised learning method DINO
Apache License 2.0

Question regarding patch retrieval demo #250

Closed crypdick closed 5 months ago

crypdick commented 1 year ago

I saw that you shared code for evaluating image-level retrieval (Issue & code). I am keen to replicate the patch-level retrieval from the official DINO demos here.

Thank you for your guidance! If there are any code snippets or resources that you can share to replicate the demo, I would be most grateful.

Gugliebiagio commented 5 months ago

Hi, have you tried to implement that?

crypdick commented 5 months ago

@Gugliebiagio You can't use the forward outputs because the forward pass returns only the global [CLS] embedding: https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L214

Instead, you have to get the raw output using get_intermediate_layers:

import torch

model = torch.hub.load("facebookresearch/dino:main", "dino_vits16")
img = torch.randn(1, 3, 224, 224)

features = model(img)
print(features.shape)  # torch.Size([1, 384]) -- already pooled to the [CLS] token

features = model.get_intermediate_layers(img, len(model.blocks))
features = features[-1]  # residual stream @ final block
print(features.shape)  # torch.Size([1, 197, 384]) -- [CLS] + 196 patch tokens

In this example, you have 1 + 14^2 = 197 embeddings (224/16 = 14 patches per side), each of which is 384-d. You can drop the first one (the [CLS] token) and map the remaining 196 embeddings back to the 196 patches.
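
For reference, here is a minimal sketch of how those patch embeddings could be used for nearest-neighbour patch retrieval. This is not the official demo code: the random input image, the choice of query patch, and the cosine-similarity ranking against patches of the same image are illustrative assumptions (in a real retrieval setup the database patches would come from other images).

import torch
import torch.nn.functional as F

model = torch.hub.load("facebookresearch/dino:main", "dino_vits16")
model.eval()

img = torch.randn(1, 3, 224, 224)  # stand-in for a normalized 224x224 image
with torch.no_grad():
    tokens = model.get_intermediate_layers(img, n=1)[-1]  # [1, 197, 384]

patch_tokens = tokens[:, 1:, :]  # drop the [CLS] token -> [1, 196, 384]

# Toy retrieval: rank every patch by cosine similarity to one query patch.
query = F.normalize(patch_tokens[0, 0], dim=-1)   # top-left patch as the query (arbitrary choice)
database = F.normalize(patch_tokens[0], dim=-1)   # all 196 patch embeddings
scores = database @ query                          # cosine similarities, shape [196]
top5 = scores.topk(5).indices                      # indices of the most similar patches

# Map flat token indices back to (row, col) positions in the 14x14 patch grid.
rows, cols = top5 // 14, top5 % 14
print(list(zip(rows.tolist(), cols.tolist())))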