ml-jku / MIM-Refiner

A Contrastive Learning Boost from Intermediate Pre-Trained Representations
MIT License

Extracting Features #12

Closed ShadeAlsha closed 2 days ago

ShadeAlsha commented 3 days ago

Hi,

I’m using the backbone model to extract features for clustering tasks, and I've observed that the output features have a shape of [257, 1280]. I am wondering if you used all of the features for clustering or just the class token.

Thanks!

BenediktAlkin commented 2 days ago

A ViT-H/14 model has 256 patches for a 224x224 image, plus 1 CLS token. The first entry is the CLS token and the following 256 are the patches, so you would use:

import torch

model = ...
image = torch.randn(1, 3, 224, 224)
model_out = model(image) # [1, 257, 1280] for ViT-H/14; [1, 197, 1024] for ViT-L/16 (batch, tokens, dim)
# extract features for classification
cls_token = model_out[:, 0]
# extract features for segmentation
patches = model_out[:, 1:]
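
For a per-image task such as clustering, the same split applies: the CLS token gives one vector per image, while the patch tokens can be arranged back into a spatial grid for dense tasks. The following is only a rough sketch under stated assumptions. `split_tokens` is a hypothetical helper, not part of this repository; the random tensor stands in for a real model output; and the patches are assumed to be in row-major order, which is common for ViT implementations but not confirmed in this thread.

```python
import torch


def split_tokens(tokens: torch.Tensor, patch_size: int = 14, image_size: int = 224):
    """Split ViT output [batch, 1 + num_patches, dim] into CLS and patch-grid features.

    Hypothetical helper for illustration; assumes the CLS token comes first
    and the patches follow in row-major order.
    """
    grid = image_size // patch_size  # 16 for ViT-H/14 at 224x224
    batch, num_tokens, dim = tokens.shape
    if num_tokens != 1 + grid * grid:
        raise ValueError("expected one CLS token plus grid * grid patch tokens")
    cls_token = tokens[:, 0]  # [batch, dim], one vector per image
    patches = tokens[:, 1:]   # [batch, grid * grid, dim]
    patch_grid = patches.reshape(batch, grid, grid, dim)
    return cls_token, patch_grid


# Stand-in for model(images): random values, not real features.
fake_out = torch.randn(4, 257, 1280)
cls_token, patch_grid = split_tokens(fake_out)
print(cls_token.shape)   # torch.Size([4, 1280])
print(patch_grid.shape)  # torch.Size([4, 16, 16, 1280])
```

The `[4, 1280]` CLS matrix can then be passed to a clustering algorithm as one feature row per image. For ViT-L/16, calling `split_tokens(out, patch_size=16)` would expect 197 tokens instead.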