huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXt, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0
30.84k stars 4.65k forks

[FEATURE] Vision Transformer Feature Extraction #657

Closed isaaccorley closed 3 months ago

isaaccorley commented 3 years ago

It seems that VisionTransformer doesn't support extracting all of the token outputs in the forward_features method. Only the cls token, or the [cls_token, distillation_token] pair, can be returned (timm/models/vision_transformer.py#L291-L304). This functionality seems particularly useful, similar to how pretrained ResNet features are commonly used for downstream tasks.

  def forward_features(self, x):
      x = self.patch_embed(x)
      cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
      if self.dist_token is None:
          x = torch.cat((cls_token, x), dim=1)
      else:
          x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
      x = self.pos_drop(x + self.pos_embed)
      x = self.blocks(x)
      x = self.norm(x)
      if self.dist_token is None:
          return self.pre_logits(x[:, 0])
      else:
          return x[:, 0], x[:, 1]

However, this is available for other models, e.g.

m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
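
For reference, with the classifier and pooling removed this way, the ResNet returns its unpooled spatial feature map; a minimal sketch (shapes are illustrative for a 224x224 input):

    import torch
    import timm

    # num_classes=0 removes the classifier, global_pool='' removes the pooling,
    # so the forward pass returns the unpooled spatial feature map
    m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
    with torch.no_grad():
        feats = m(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # torch.Size([1, 2048, 7, 7])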

I implemented something similar for a side project here that required all ViT outputs for a downstream segmentation task. I simply overrode the method in my example, but I assume some attribute could be added to VisionTransformer to allow returning the 'unpooled' output. Maybe something like this:

  def forward_features(self, x):
      x = self.patch_embed(x)
      cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
      if self.dist_token is None:
          x = torch.cat((cls_token, x), dim=1)
      else:
          x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
      x = self.pos_drop(x + self.pos_embed)
      x = self.blocks(x)

      if self.unpooled:
          # return all patch tokens, dropping the class (and distillation) tokens;
          # note: self.norm is not applied in this branch
          if self.dist_token is None:
              return x[:, 1:]
          else:
              return x[:, 2:]
      else:
          x = self.norm(x)
          if self.dist_token is None:
              return self.pre_logits(x[:, 0])
          else:
              return x[:, 0], x[:, 1]

  def forward(self, x):
      x = self.forward_features(x)
      if not self.unpooled:
          if self.head_dist is not None:
              x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
              if self.training and not torch.jit.is_scripting():
                  # during training, return both classifier predictions
                  return x, x_dist
              else:
                  # during inference, return the average of both classifier predictions
                  return (x + x_dist) / 2
          else:
              x = self.head(x)
      return x
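
For reference, the same kind of override can also be applied externally, without modifying the library; a minimal sketch using the attribute names from the method quoted above (shapes are illustrative):

    import types
    import torch
    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=True)

    def forward_features_unpooled(self, x):
        # same stem as forward_features above, but return every patch token
        # instead of the pooled cls output
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        return x[:, 1:] if self.dist_token is None else x[:, 2:]

    model.forward_features = types.MethodType(forward_features_unpooled, model)
    tokens = model.forward_features(torch.randn(1, 3, 224, 224))
    print(tokens.shape)  # torch.Size([1, 196, 768]) for a 224x224 input with 16x16 patches
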
rwightman commented 3 years ago

@isaaccorley already on the radar, but I haven't had a chance to come up with a design yet. See #617. I was going to rename the title of that one, but I'll keep this one open instead and close #617.

One complication is that new vision transformer models have been coming in at a rapid rate. Any solution has to work for all of them, whether their features are spatial or flattened, whether the user wants flat features unflattened when there is enough info to do so, etc.
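
For the flat case, unflattening is only possible when the patch grid is known; a minimal sketch assuming a square grid and the class token already stripped:

    import torch

    # tokens: [B, N, C] patch tokens with cls/dist tokens already removed
    tokens = torch.randn(2, 196, 768)   # e.g. 224x224 input, 16x16 patches -> 14x14 grid
    B, N, C = tokens.shape
    H = W = int(N ** 0.5)               # only valid for a square grid
    spatial = tokens.transpose(1, 2).reshape(B, C, H, W)  # [2, 768, 14, 14]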

I would also like to do it without significant modifications or new conditionals in any of the forward fns, instead just adding extra metadata in a feature_info attribute, since it's specific to each model. This is what I did for the CNNs, and I wrote FeatureExtractor modules to work with that metadata.
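
For context, the CNN-side mechanism referred to here is the existing features_only / feature_info path in timm; a quick sketch:

    import torch
    import timm

    # features_only builds a feature-extraction wrapper driven by feature_info metadata
    m = timm.create_model('resnet50', pretrained=True, features_only=True, out_indices=(1, 2, 3, 4))
    print(m.feature_info.channels())  # e.g. [256, 512, 1024, 2048]
    feats = m(torch.randn(1, 3, 224, 224))
    print([f.shape for f in feats])   # multi-scale spatial feature maps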

There is also the consideration of tokens, whether to include them where possible (if they exist) or not (#655).

And finally, the torchvision team let me know they're working on some new functionality using torch.fx that would allow writing more general feature extraction. I was thinking of exploring that as a possible approach...
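
For a rough idea of that direction: torchvision's create_feature_extractor (added in torchvision 0.11 / PyTorch 1.10) grabs arbitrary graph nodes by name. Whether a given timm model traces cleanly is exactly the open question, so treat this as a sketch, not a confirmed approach:

    import torch
    import timm
    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

    vit = timm.create_model('vit_base_patch16_224', pretrained=True)
    # get_graph_node_names lists valid node names; 'blocks.11' is the last transformer block here
    extractor = create_feature_extractor(vit, return_nodes={'blocks.11': 'tokens'})
    out = extractor(torch.randn(1, 3, 224, 224))
    print(out['tokens'].shape)  # torch.Size([1, 197, 768]), cls token included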

isaaccorley commented 3 years ago

@rwightman After reviewing some of the other transformer variant code, I understand what you mean. I have some additional bandwidth to help out if needed when you want to tackle this. In the meantime I think a FeatureExtractor module, but for vision transformers, would work nicely; torch.fx also seems like it has the ability to solve this exact problem.

rwightman commented 3 years ago

@isaaccorley some help could be useful here. I need to ask the torchvision team and figure out what the fx solution might look like and whether it has a chance of working well with torchscript, and then decide between putting the effort in there or making a vision transformer / vision MLP specific feature extraction module.

JarvisKevin commented 2 years ago

The same problem is present in other models, such as 'mobilenetv3_large_100'. The global_pool operator is embedded in the 'forward_features' function, which is inconsistent with the tutorial at https://rwightman.github.io/pytorch-image-models/feature_extraction/ @rwightman

rwightman commented 2 years ago

@JarvisKevin mnv3 is a completely different situation. If you read the comments in that model definition, that architecture by design does not match other convnets due to its 'efficient head', where global avg pooling is moved in front of the penultimate feature layer. The current choice for forward_features vs features was made to balance simplicity and compatibility with as many of the feature extraction use cases as possible.
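
If the goal is spatial features from mnv3 despite that head design, the features_only path still returns the pre-pooling stage outputs; a minimal sketch:

    import torch
    import timm

    # stage outputs are taken before the efficient head, so they remain spatial
    m = timm.create_model('mobilenetv3_large_100', pretrained=True, features_only=True)
    feats = m(torch.randn(1, 3, 224, 224))
    print([f.shape for f in feats])  # one feature map per stage, strides 2 through 32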

Re ViT... I have in-progress refactoring / feature changes for vit/mlp models in the works. It's pretty significant and relies on some helpers that won't be in torchvision until the PyTorch 1.10 release, so I'm sitting on them for a while and making sure they are the right approach...

caihaunqai commented 2 years ago

Hello, in some other code I've seen the unpooled features taken before x = self.pos_drop(x + self.pos_embed), while the code above takes them after it. I wonder which method works better.

oguz-hanoglu commented 9 months ago

I believe this issue should be considered resolved now that unpooled forward_features output has been implemented (https://github.com/huggingface/pytorch-image-models/issues/1029).
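
For anyone landing here now, a minimal sketch of the current behavior (recent timm versions; exact shapes depend on the model):

    import torch
    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    x = torch.randn(1, 3, 224, 224)
    tokens = model.forward_features(x)                    # unpooled token sequence, e.g. [1, 197, 768]
    pooled = model.forward_head(tokens, pre_logits=True)  # pooled embedding, e.g. [1, 768]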