lukemelas / PyTorch-Pretrained-ViT

Vision Transformer (ViT) in PyTorch
770 stars 124 forks

Extract the transformer intermediate layer #25

Open leolv131 opened 2 years ago

leolv131 commented 2 years ago

I want to extract the transformer's intermediate layers. I used the following code, but it does not work: nn.Sequential(*list(model.children())). How should I do this?

Abeldewit commented 2 years ago

Hi, I was trying to accomplish the same thing with the same code. Instead, I created a subclass of the ViT class that overrides the forward pass to skip the last two layers, which are only used for classification.

import pytorch_pretrained_vit as ptv
from pytorch_pretrained_vit.model import PositionalEmbedding1D

class EncoderVit(ptv.ViT):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # 576 patch tokens ((384 / 16) ** 2 for a 384x384 input with
        # 16x16 patches) and hidden dim 768 (the B_16 models). The class
        # token is dropped below, so the embedding covers 576 positions,
        # not 577.
        self.positional_embedding = PositionalEmbedding1D(576, 768)

    def forward(self, x):
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)  # (b, 576, 768)
        # x = torch.cat((self.class_token.expand(1, -1, -1), x), dim=1)
        x = self.positional_embedding(x)
        x = self.transformer(x)
        return x  # patch-token features, no final norm / classifier head

I needed this to use the ViT as an encoder, and I'm guessing you do too. Hope this helps!
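Abeldewit commented 2 years ago

If you want the output of a specific transformer block rather than the final encoder output, a forward hook avoids subclassing entirely. A minimal sketch of the technique below uses a tiny stand-in module instead of the real ViT (so it runs without pretrained weights); with this repo you would hook something like model.transformer.blocks[i] instead (attribute names assumed from the Transformer implementation, please check your version). Note the hook captures the hooked module's raw output, before any activation applied outside it.

    import torch
    import torch.nn as nn

    class TinyEncoder(nn.Module):
        """Stand-in for the ViT: a stack of blocks applied in sequence."""
        def __init__(self, dim=16, depth=3):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

        def forward(self, x):
            for blk in self.blocks:
                x = torch.relu(blk(x))
            return x

    model = TinyEncoder()
    features = {}

    def save_output(name):
        # Hook signature is (module, inputs, output); stash the output.
        def hook(module, inputs, output):
            features[name] = output.detach()
        return hook

    # Hook the second block; for the ViT above this would be
    # model.transformer.blocks[1].register_forward_hook(...)
    handle = model.blocks[1].register_forward_hook(save_output("block_1"))

    x = torch.randn(2, 8, 16)  # (batch, tokens, dim)
    _ = model(x)
    handle.remove()  # detach the hook once the features are captured
    print(features["block_1"].shape)  # torch.Size([2, 8, 16])

The same pattern works on the unmodified pretrained model, so you can keep the classifier head and still read out intermediate token features.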