lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Accessing last layer hidden states or embeddings for models like CrossViT, RegionViT (Extractor doesn't seem to work) #221

Closed. PrithivirajDamodaran closed this 2 years ago.

PrithivirajDamodaran commented 2 years ago

How can I access the last-layer hidden states, i.e. the embeddings of an image, from models like CrossViT and RegionViT? The Extractor option works only on the vanilla ViT.

Please advise.
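
For reference, the vanilla ViT usage that does work is roughly the following (a sketch following the repo README; the constructor values are just illustrative):

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# wrap the ViT so the forward pass also returns the last hidden states
v = Extractor(v)

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)  # embeddings: (1, 65, 1024) - (batch, num patches + CLS token, dim)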

lucidrains commented 2 years ago

@PrithivirajDamodaran Hi Prithivida! Let me know if https://github.com/lucidrains/vit-pytorch/commit/4e62e5f05ee03ce46a5d1fe51b1d996c701786ec works now
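
After that commit, usage along these lines should work (a sketch; the CrossViT hyperparameters follow the repo README, and the exact structure of the returned embeddings is not verified here):

import torch
from vit_pytorch.cross_vit import CrossViT
from vit_pytorch.extractor import Extractor

model = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high-res (small patch) branch dimension
    sm_patch_size = 16,
    sm_enc_depth = 2,
    sm_enc_heads = 8,
    sm_enc_mlp_dim = 2048,
    lg_dim = 384,            # low-res (large patch) branch dimension
    lg_patch_size = 64,
    lg_enc_depth = 3,
    lg_enc_heads = 8,
    lg_enc_mlp_dim = 2048,
    cross_attn_depth = 2,
    cross_attn_heads = 8
)

# wrap the CrossViT so the forward pass also returns the hidden states
v = Extractor(model)

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)  # embeddings hold the last hidden states of the extracted layer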

lucidrains commented 2 years ago

RegionViT can also work, if you pass in a reference to the layer whose output you would like to extract:

import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

# wrap the RegionViT with the Extractor

from vit_pytorch.extractor import Extractor
v = Extractor(model, layer = model.layers[-1][-1])

# forward pass now returns the logits as well as the extracted embeddings

img = torch.randn(1, 3, 224, 224)
logits, embeddings = v(img)

# embeddings is a tuple with the outputs of the two paths: the local tokens and the regional tokens

embeddings # ((1, 512, 7, 7), (1, 512, 1, 1))

PrithivirajDamodaran commented 2 years ago

Thank you, will check and close. Big fan of your work.

PrithivirajDamodaran commented 2 years ago

Works fine! So just to be sure, for a single image the tuple is:

(1, 512, 7, 7) - last layer embedding
(1, 512, 1, 1) - CLS embedding

Is that the right understanding?

lucidrains commented 2 years ago

@PrithivirajDamodaran so RegionViT is a bit different from a conventional neural net, in that it keeps two separate information paths and has them cross-attend to each other, iirc

what you are seeing is the outputs of those two separate paths: one is the normal network output, the other is the "regional" tokens

lucidrains commented 2 years ago

@PrithivirajDamodaran if you are doing anything downstream, I would concat those two together for a 1024-dimensional embedding:

import torch
from einops import reduce

fine_embed, region_embed = embeddings  # unpack the tuple returned by the Extractor above
embedding = torch.cat((reduce(fine_embed, 'b c h w -> b c', 'mean'), reduce(region_embed, 'b c h w -> b c', 'mean')), dim = -1)  # (1, 1024)

mathshangw commented 2 years ago

Excuse me, what if I need to remove the model's final classification layer, to get the features before classification?

mathshangw commented 2 years ago

Is there any help, please?
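
For anyone else looking: one option is to wrap the model with Extractor as shown above; another is to replace the classification head with an identity module. A minimal sketch for the vanilla ViT, assuming its head attribute is still named mlp_head (other models in the repo name their heads differently):

import torch
from torch import nn
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# swap the classification head for an identity, so the forward
# pass returns the pooled features instead of class logits
v.mlp_head = nn.Identity()

img = torch.randn(1, 3, 256, 256)
features = v(img)  # (1, 1024) - the embedding right before classification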